diff --git a/train.py b/train.py index a13433ca..5a229907 100644 --- a/train.py +++ b/train.py @@ -199,9 +199,11 @@ def train(): # Dataloader batch_size = min(batch_size, len(dataset)) + nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) + print('Using %g dataloader workers' % nw) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]), + num_workers=nw, shuffle=not opt.rect, # Shuffle=True unless rectangular training is used pin_memory=True, collate_fn=dataset.collate_fn)