updates
commit bc0f30933a
parent ff9d343019
train.py | 7
@@ -85,6 +85,7 @@ def train(
     # Set scheduler
     # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
 
     # Start training
     t0 = time.time()
     model_info(model)
+    n_burnin = min(round(dataloader.nB / 5 + 1), 1000)  # number of burn-in batches
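The added n_burnin line caps learning-rate burn-in at roughly a fifth of the batches in an epoch (at most 1000). Below is a minimal, self-contained sketch of how such a burn-in count is typically consumed: the learning rate ramps from near zero up to the base rate over the first n_burnin batches, so early SGD steps don't diverge. The toy model, lr0, nB, and the quartic ramp are illustrative assumptions, not part of this commit.

import torch

model = torch.nn.Linear(10, 1)           # stand-in for the detection model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
lr0 = 0.001                              # assumed base learning rate
nB = 500                                 # assumed number of batches per epoch
n_burnin = min(round(nB / 5 + 1), 1000)  # number of burn-in batches, as in the diff

for i in range(nB):
    if i <= n_burnin:
        # Smooth quartic ramp from ~0 to lr0 (ramp shape is illustrative)
        lr = lr0 * (i / n_burnin) ** 4
        for g in optimizer.param_groups:
            g['lr'] = lr
    # ... forward pass, loss.backward(), optimizer.step() would follow here ...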
@@ -124,11 +125,13 @@ def train(
             for g in optimizer.param_groups:
                 g['lr'] = lr
 
-        # Compute loss, compute gradient, update parameters
+        # Compute loss
         loss = model(imgs.to(device), targets, var=var)
+
+        # Compute gradient
         loss.backward()
 
-        # accumulate gradient for x batches before optimizing
+        # Accumulate gradient for x batches before optimizing
         if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
             optimizer.step()
             optimizer.zero_grad()
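This hunk only splits and tidies the comments, but the surrounding pattern is gradient accumulation: gradients from accumulated_batches consecutive batches are summed into .grad before a single optimizer.step(), emulating a larger effective batch size. Here is a self-contained sketch of that pattern; the toy model, data, and loss function are assumptions, while the step/zero_grad condition mirrors the diff.

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulated_batches = 4  # effective batch size = 4 x per-batch size
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(10)]

for i, (imgs, targets) in enumerate(batches):
    # Compute loss
    loss = torch.nn.functional.mse_loss(model(imgs), targets)

    # Compute gradient (accumulates into each parameter's .grad)
    loss.backward()

    # Accumulate gradient for x batches before optimizing;
    # also step on the final batch so no gradients are dropped
    if ((i + 1) % accumulated_batches == 0) or (i == len(batches) - 1):
        optimizer.step()
        optimizer.zero_grad()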