diff --git a/train.py b/train.py index d463a650..d67a6296 100644 --- a/train.py +++ b/train.py @@ -153,10 +153,10 @@ def train( loss = model(imgs.to(device), targets, batch_report=report, var=var) loss.backward() - accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing - if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): - optimizer.step() - optimizer.zero_grad() + # accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing + # if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): + optimizer.step() + optimizer.zero_grad() # Running epoch-means of tracked metrics ui += 1