This commit is contained in:
Glenn Jocher 2019-06-15 01:35:55 +02:00
parent 8f609246db
commit bb3682024e
1 changed files with 3 additions and 2 deletions

View File

@ -140,11 +140,12 @@ def train(
# plt.savefig('LR.png', dpi=300) # plt.savefig('LR.png', dpi=300)
# Dataset # Dataset
rectangular_training = False
dataset = LoadImagesAndLabels(train_path, dataset = LoadImagesAndLabels(train_path,
img_size, img_size,
batch_size, batch_size,
augment=True, augment=True,
rect=False) rect=rectangular_training)
# Initialize distributed training # Initialize distributed training
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
@ -156,7 +157,7 @@ def train(
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=opt.num_workers, num_workers=opt.num_workers,
shuffle=True, # disable rectangular training if True shuffle=not rectangular_training, # Shuffle=True unless rectangular training is used
pin_memory=True, pin_memory=True,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)