weight_decay fix

This commit is contained in:
Glenn Jocher 2019-08-29 14:29:07 +02:00
parent 31d807e589
commit 7d9ffe6d4e
1 changed files with 10 additions and 5 deletions

View File

@ -258,9 +258,6 @@ def train():
# Compute loss
loss, loss_items = compute_loss(pred, targets, model)
if torch.isnan(loss):
print('WARNING: nan loss detected, ending training', loss_items)
return results
# Scale loss by nominal batch_size of 64
loss *= batch_size / 64
@ -282,8 +279,14 @@ def train():
mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0 # (GB)
s = ('%10s' * 2 + '%10.3g' * 6) % (
'%g/%g' % (epoch, epochs - 1), '%.3gG' % mem, *mloss, len(targets), img_size)
pbar.set_description(s) # end batch -----------------------------------------------------------------------
pbar.set_description(s)
if torch.isnan(loss):
print('WARNING: nan loss detected, ending training', loss_items)
return results
# end batch ------------------------------------------------------------------------------------------------
# Process epoch results
final_epoch = epoch + 1 == epochs
if opt.prebias:
print_model_biases(model)
@ -342,7 +345,9 @@ def train():
torch.save(chkpt, wdir + 'backup%g.pt' % epoch)
# Delete checkpoint
del chkpt # end epoch -------------------------------------------------------------------------------------
del chkpt
# end epoch ----------------------------------------------------------------------------------------------------
# Report time
plot_results() # save as results.png