Update train.py
This commit is contained in:
parent
c719792d6b
commit
c1c09eb3cc
7
train.py
7
train.py
|
@ -151,9 +151,8 @@ def train(
|
|||
print(s)
|
||||
|
||||
# Update best loss
|
||||
loss_per_target = rloss['loss'] / rloss['nT']
|
||||
if loss_per_target < best_loss:
|
||||
best_loss = loss_per_target
|
||||
if rloss['loss'] < best_loss:
|
||||
best_loss = rloss['loss']
|
||||
|
||||
# Save latest checkpoint
|
||||
checkpoint = {'epoch': epoch,
|
||||
|
@ -163,7 +162,7 @@ def train(
|
|||
torch.save(checkpoint, latest)
|
||||
|
||||
# Save best checkpoint
|
||||
if best_loss == loss_per_target:
|
||||
if best_loss == rloss['loss']:
|
||||
os.system('cp ' + latest + ' ' + best)
|
||||
|
||||
# Save backup weights every 5 epochs (optional)
|
||||
|
|
Loading…
Reference in New Issue