Glenn Jocher 2019-03-07 17:16:38 +01:00
parent ff9d343019
commit bc0f30933a
1 changed file with 5 additions and 2 deletions


@@ -85,6 +85,7 @@ def train(
     # Set scheduler
     # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
 
+    # Start training
     t0 = time.time()
     model_info(model)
     n_burnin = min(round(dataloader.nB / 5 + 1), 1000)  # number of burn-in batches
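
For reference, n_burnin above caps the warm-up at 1000 batches (one fifth of an epoch plus one, whichever is smaller). Below is a minimal sketch of how such a burn-in could drive the per-group learning-rate assignment seen in the next hunk; the quartic ramp, the lr0 value, and the toy optimizer are assumptions for illustration, not taken from this commit.

import torch

def burnin_lr(optimizer, lr0, i, n_burnin):
    # Ramp the learning rate from 0 toward lr0 over the first n_burnin batches.
    # The quartic exponent is an assumption; the diff only shows that each
    # param group's 'lr' is overwritten per batch during burn-in.
    lr = lr0 * (i / n_burnin) ** 4
    for g in optimizer.param_groups:
        g['lr'] = lr
    return lr

# Usage: dataloader.nB (batches per epoch) is assumed to be 120 here.
optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=1e-3)
n_burnin = min(round(120 / 5 + 1), 1000)
for i in range(1, 4):
    print(burnin_lr(optimizer, lr0=1e-3, i=i, n_burnin=n_burnin))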
@@ -124,11 +125,13 @@ def train(
             for g in optimizer.param_groups:
                 g['lr'] = lr
 
-        # Compute loss, compute gradient, update parameters
+        # Compute loss
         loss = model(imgs.to(device), targets, var=var)
+
+        # Compute gradient
         loss.backward()
 
-        # accumulate gradient for x batches before optimizing
+        # Accumulate gradient for x batches before optimizing
         if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
             optimizer.step()
             optimizer.zero_grad()
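
The comments added above document a gradient-accumulation pattern: loss.backward() adds into each parameter's .grad buffer every batch, while optimizer.step() and optimizer.zero_grad() run only every accumulated_batches batches (or on the final batch), emulating a larger effective batch size. A minimal runnable sketch of the same pattern, with a toy linear model and random data standing in for the repository's model and dataloader:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulated_batches = 4  # name taken from the diff; the value is an example
dataloader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(10)]

for i, (imgs, targets) in enumerate(dataloader):
    # Compute loss
    loss = F.mse_loss(model(imgs), targets)

    # Compute gradient (accumulates into .grad across iterations)
    loss.backward()

    # Accumulate gradient for x batches before optimizing
    if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
        optimizer.step()
        optimizer.zero_grad()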