diff --git a/train.py b/train.py index 4406dfca..12638c0a 100644 --- a/train.py +++ b/train.py @@ -43,7 +43,7 @@ def train( # Dataloader dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True) - dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) cutoff = -1 # backbone reaches to cutoff layer start_epoch = 0