weight_decay fix
This commit is contained in:
parent
31d807e589
commit
7d9ffe6d4e
15
train.py
15
train.py
|
@ -258,9 +258,6 @@ def train():
|
|||
|
||||
# Compute loss
|
||||
loss, loss_items = compute_loss(pred, targets, model)
|
||||
if torch.isnan(loss):
|
||||
print('WARNING: nan loss detected, ending training', loss_items)
|
||||
return results
|
||||
|
||||
# Scale loss by nominal batch_size of 64
|
||||
loss *= batch_size / 64
|
||||
|
@ -282,8 +279,14 @@ def train():
|
|||
mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s = ('%10s' * 2 + '%10.3g' * 6) % (
|
||||
'%g/%g' % (epoch, epochs - 1), '%.3gG' % mem, *mloss, len(targets), img_size)
|
||||
pbar.set_description(s) # end batch -----------------------------------------------------------------------
|
||||
pbar.set_description(s)
|
||||
if torch.isnan(loss):
|
||||
print('WARNING: nan loss detected, ending training', loss_items)
|
||||
return results
|
||||
|
||||
# end batch ------------------------------------------------------------------------------------------------
|
||||
|
||||
# Process epoch results
|
||||
final_epoch = epoch + 1 == epochs
|
||||
if opt.prebias:
|
||||
print_model_biases(model)
|
||||
|
@ -342,7 +345,9 @@ def train():
|
|||
torch.save(chkpt, wdir + 'backup%g.pt' % epoch)
|
||||
|
||||
# Delete checkpoint
|
||||
del chkpt # end epoch -------------------------------------------------------------------------------------
|
||||
del chkpt
|
||||
|
||||
# end epoch ----------------------------------------------------------------------------------------------------
|
||||
|
||||
# Report time
|
||||
plot_results() # save as results.png
|
||||
|
|
Loading…
Reference in New Issue