updates
This commit is contained in:
parent
8f609246db
commit
bb3682024e
5
train.py
5
train.py
|
@ -140,11 +140,12 @@ def train(
|
|||
# plt.savefig('LR.png', dpi=300)
|
||||
|
||||
# Dataset
|
||||
rectangular_training = False
|
||||
dataset = LoadImagesAndLabels(train_path,
|
||||
img_size,
|
||||
batch_size,
|
||||
augment=True,
|
||||
rect=False)
|
||||
rect=rectangular_training)
|
||||
|
||||
# Initialize distributed training
|
||||
if torch.cuda.device_count() > 1:
|
||||
|
@ -156,7 +157,7 @@ def train(
|
|||
dataloader = DataLoader(dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=opt.num_workers,
|
||||
shuffle=True, # disable rectangular training if True
|
||||
shuffle=not rectangular_training, # Shuffle=True unless rectangular training is used
|
||||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn)
|
||||
|
||||
|
|
Loading…
Reference in New Issue