Update train.py

This commit is contained in:
Glenn Jocher 2019-03-10 15:03:17 +01:00 committed by GitHub
parent c719792d6b
commit c1c09eb3cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 4 deletions

View File

@ -151,9 +151,8 @@ def train(
print(s) print(s)
# Update best loss # Update best loss
loss_per_target = rloss['loss'] / rloss['nT'] if rloss['loss'] < best_loss:
if loss_per_target < best_loss: best_loss = rloss['loss']
best_loss = loss_per_target
# Save latest checkpoint # Save latest checkpoint
checkpoint = {'epoch': epoch, checkpoint = {'epoch': epoch,
@ -163,7 +162,7 @@ def train(
torch.save(checkpoint, latest) torch.save(checkpoint, latest)
# Save best checkpoint # Save best checkpoint
if best_loss == loss_per_target: if best_loss == rloss['loss']:
os.system('cp ' + latest + ' ' + best) os.system('cp ' + latest + ' ' + best)
# Save backup weights every 5 epochs (optional) # Save backup weights every 5 epochs (optional)