updates
This commit is contained in:
parent
21ab0c76fd
commit
900851200e
4
train.py
4
train.py
|
@ -153,8 +153,8 @@ def train(
|
|||
loss = model(imgs.to(device), targets, batch_report=report, var=var)
|
||||
loss.backward()
|
||||
|
||||
accumulated_batches = 4 # accumulate gradient for 4 batches before optimizing
|
||||
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||
# accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing
|
||||
# if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
|
Loading…
Reference in New Issue