diff --git a/train.py b/train.py index f53ddf60..6b7b6181 100644 --- a/train.py +++ b/train.py @@ -151,9 +151,8 @@ def train( print(s) # Update best loss - loss_per_target = rloss['loss'] / rloss['nT'] - if loss_per_target < best_loss: - best_loss = loss_per_target + if rloss['loss'] < best_loss: + best_loss = rloss['loss'] # Save latest checkpoint checkpoint = {'epoch': epoch, @@ -163,7 +162,7 @@ def train( torch.save(checkpoint, latest) # Save best checkpoint - if best_loss == loss_per_target: + if best_loss == rloss['loss']: os.system('cp ' + latest + ' ' + best) # Save backup weights every 5 epochs (optional)