diff --git a/train.py b/train.py index 5e9ee0e6..b791d7e4 100644 --- a/train.py +++ b/train.py @@ -36,6 +36,7 @@ def train( # Get dataloader dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True) + # dataloader = torch.utils.data.DataLoader(dataloader, batch_size=batch_size, num_workers=0) lr0 = 0.001 # initial learning rate cutoff = -1 # backbone reaches to cutoff layer @@ -81,7 +82,7 @@ def train( # Start training t0 = time.time() model_info(model) - n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches + n_burnin = min(round(len(dataloader) / 5 + 1), 1000) # burn-in batches for epoch in range(epochs): model.train() epoch += start_epoch