This commit is contained in:
Glenn Jocher 2019-11-20 19:34:22 -08:00
parent a0067ac8fb
commit f38723c0bd
2 changed files with 4 additions and 2 deletions

View File

@ -47,9 +47,10 @@ def test(cfg,
# Dataloader # Dataloader
dataset = LoadImagesAndLabels(test_path, img_size, batch_size) dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
batch_size = min(batch_size, len(dataset))
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=min([os.cpu_count(), batch_size, 16]), num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
pin_memory=True, pin_memory=True,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)

View File

@ -193,9 +193,10 @@ def train():
cache_images=False if opt.prebias else opt.cache_images) cache_images=False if opt.prebias else opt.cache_images)
# Dataloader # Dataloader
batch_size = min(batch_size, len(dataset))
dataloader = torch.utils.data.DataLoader(dataset, dataloader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=min([os.cpu_count(), batch_size, 16]), num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
pin_memory=True, pin_memory=True,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)