diff --git a/train.py b/train.py index efa71951..8b859539 100644 --- a/train.py +++ b/train.py @@ -236,7 +236,7 @@ def train(): mloss = torch.zeros(5).to(device) # mean losses pbar = tqdm(enumerate(dataloader), total=nb) # progress bar for i, (imgs, targets, paths, _) in pbar: # batch ------------------------------------------------------------- - ni = (i + nb * epoch) # number integrated batches (since train start) + ni = i + nb * epoch # number integrated batches (since train start) imgs = imgs.to(device) targets = targets.to(device)