From f38723c0bd148191f580eef00016780ab04ea914 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 20 Nov 2019 19:34:22 -0800 Subject: [PATCH] updates --- test.py | 3 ++- train.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test.py b/test.py index ebe60902..05cc0862 100644 --- a/test.py +++ b/test.py @@ -47,9 +47,10 @@ def test(cfg, # Dataloader dataset = LoadImagesAndLabels(test_path, img_size, batch_size) + batch_size = min(batch_size, len(dataset)) dataloader = DataLoader(dataset, batch_size=batch_size, - num_workers=min([os.cpu_count(), batch_size, 16]), + num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]), pin_memory=True, collate_fn=dataset.collate_fn) diff --git a/train.py b/train.py index bb9562f3..3a4bc34f 100644 --- a/train.py +++ b/train.py @@ -193,9 +193,10 @@ def train(): cache_images=False if opt.prebias else opt.cache_images) # Dataloader + batch_size = min(batch_size, len(dataset)) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, - num_workers=min([os.cpu_count(), batch_size, 16]), + num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]), shuffle=not opt.rect, # Shuffle=True unless rectangular training is used pin_memory=True, collate_fn=dataset.collate_fn)