Update train.py
This commit is contained in:
parent
c719792d6b
commit
c1c09eb3cc
7
train.py
7
train.py
|
@ -151,9 +151,8 @@ def train(
|
||||||
print(s)
|
print(s)
|
||||||
|
|
||||||
# Update best loss
|
# Update best loss
|
||||||
loss_per_target = rloss['loss'] / rloss['nT']
|
if rloss['loss'] < best_loss:
|
||||||
if loss_per_target < best_loss:
|
best_loss = rloss['loss']
|
||||||
best_loss = loss_per_target
|
|
||||||
|
|
||||||
# Save latest checkpoint
|
# Save latest checkpoint
|
||||||
checkpoint = {'epoch': epoch,
|
checkpoint = {'epoch': epoch,
|
||||||
|
@ -163,7 +162,7 @@ def train(
|
||||||
torch.save(checkpoint, latest)
|
torch.save(checkpoint, latest)
|
||||||
|
|
||||||
# Save best checkpoint
|
# Save best checkpoint
|
||||||
if best_loss == loss_per_target:
|
if best_loss == rloss['loss']:
|
||||||
os.system('cp ' + latest + ' ' + best)
|
os.system('cp ' + latest + ' ' + best)
|
||||||
|
|
||||||
# Save backup weights every 5 epochs (optional)
|
# Save backup weights every 5 epochs (optional)
|
||||||
|
|
Loading…
Reference in New Issue