diff --git a/train.py b/train.py index 938e19e5..127adad2 100644 --- a/train.py +++ b/train.py @@ -258,9 +258,6 @@ def train(): # Compute loss loss, loss_items = compute_loss(pred, targets, model) - if torch.isnan(loss): - print('WARNING: nan loss detected, ending training', loss_items) - return results # Scale loss by nominal batch_size of 64 loss *= batch_size / 64 @@ -282,8 +279,14 @@ def train(): mem = torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0 # (GB) s = ('%10s' * 2 + '%10.3g' * 6) % ( '%g/%g' % (epoch, epochs - 1), '%.3gG' % mem, *mloss, len(targets), img_size) - pbar.set_description(s) # end batch ----------------------------------------------------------------------- + pbar.set_description(s) + if torch.isnan(loss): + print('WARNING: nan loss detected, ending training', loss_items) + return results + # end batch ------------------------------------------------------------------------------------------------ + + # Process epoch results final_epoch = epoch + 1 == epochs if opt.prebias: print_model_biases(model) @@ -342,7 +345,9 @@ def train(): torch.save(chkpt, wdir + 'backup%g.pt' % epoch) # Delete checkpoint - del chkpt # end epoch ------------------------------------------------------------------------------------- + del chkpt + + # end epoch ---------------------------------------------------------------------------------------------------- # Report time plot_results() # save as results.png