From 900851200e8132c713f6436ae1aa95b470e8abf7 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 15 Dec 2018 20:52:35 +0100 Subject: [PATCH] updates --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index d463a650..d67a6296 100644 --- a/train.py +++ b/train.py @@ -153,10 +153,10 @@ def train( loss = model(imgs.to(device), targets, batch_report=report, var=var) loss.backward() - accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing - if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): - optimizer.step() - optimizer.zero_grad() + # accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing + # if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): + optimizer.step() + optimizer.zero_grad() # Running epoch-means of tracked metrics ui += 1