updates
This commit is contained in:
parent
ff9d343019
commit
bc0f30933a
7
train.py
7
train.py
|
@ -85,6 +85,7 @@ def train(
|
||||||
# Set scheduler
|
# Set scheduler
|
||||||
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
|
||||||
|
|
||||||
|
# Start training
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
model_info(model)
|
model_info(model)
|
||||||
n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches
|
n_burnin = min(round(dataloader.nB / 5 + 1), 1000) # number of burn-in batches
|
||||||
|
@ -124,11 +125,13 @@ def train(
|
||||||
for g in optimizer.param_groups:
|
for g in optimizer.param_groups:
|
||||||
g['lr'] = lr
|
g['lr'] = lr
|
||||||
|
|
||||||
# Compute loss, compute gradient, update parameters
|
# Compute loss
|
||||||
loss = model(imgs.to(device), targets, var=var)
|
loss = model(imgs.to(device), targets, var=var)
|
||||||
|
|
||||||
|
# Compute gradient
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# accumulate gradient for x batches before optimizing
|
# Accumulate gradient for x batches before optimizing
|
||||||
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
Loading…
Reference in New Issue