From bc0f30933a0c905b68dc1f9e88e94ef56c5050ff Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 7 Mar 2019 17:16:38 +0100 Subject: [PATCH] updates --- train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index a0dd3222..f53ddf60 100644 --- a/train.py +++ b/train.py @@ -85,6 +85,7 @@ def train( # Set scheduler # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1) + # Start training t0 = time.time() model_info(model) n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches @@ -124,11 +125,13 @@ def train( for g in optimizer.param_groups: g['lr'] = lr - # Compute loss, compute gradient, update parameters + # Compute loss loss = model(imgs.to(device), targets, var=var) + + # Compute gradient loss.backward() - # accumulate gradient for x batches before optimizing + # Accumulate gradient for x batches before optimizing if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): optimizer.step() optimizer.zero_grad()