updates
This commit is contained in:
parent
21ab0c76fd
commit
900851200e
8
train.py
8
train.py
|
@ -153,10 +153,10 @@ def train(
|
||||||
loss = model(imgs.to(device), targets, batch_report=report, var=var)
|
loss = model(imgs.to(device), targets, batch_report=report, var=var)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing
|
# accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing
|
||||||
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
# if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# Running epoch-means of tracked metrics
|
# Running epoch-means of tracked metrics
|
||||||
ui += 1
|
ui += 1
|
||||||
|
|
Loading…
Reference in New Issue