diff --git a/train.py b/train.py index 70a9354e..54fbb8b5 100644 --- a/train.py +++ b/train.py @@ -140,11 +140,12 @@ def train( # plt.savefig('LR.png', dpi=300) # Dataset + rectangular_training = False dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True, - rect=False) + rect=rectangular_training) # Initialize distributed training if torch.cuda.device_count() > 1: @@ -156,7 +157,7 @@ def train( dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=opt.num_workers, - shuffle=True, # disable rectangular training if True + shuffle=not rectangular_training, # Shuffle=True unless rectangular training is used pin_memory=True, collate_fn=dataset.collate_fn)