This commit is contained in:
Glenn Jocher 2018-12-15 20:52:35 +01:00
parent 21ab0c76fd
commit 900851200e
1 changed files with 4 additions and 4 deletions

View File

@ -153,10 +153,10 @@ def train(
loss = model(imgs.to(device), targets, batch_report=report, var=var)
loss.backward()
accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
optimizer.step()
optimizer.zero_grad()
# accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing
# if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
optimizer.step()
optimizer.zero_grad()
# Running epoch-means of tracked metrics
ui += 1