This commit is contained in:
Glenn Jocher 2019-04-02 16:33:52 +02:00
parent 658f2a4576
commit d526ce0d11
1 changed files with 4 additions and 3 deletions

View File

@ -180,15 +180,16 @@ def train(
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest)
del checkpoint
# Save best checkpoint
if best_loss == mloss['total']:
os.system('cp ' + latest + ' ' + best)
torch.save(checkpoint, best)
# Save backup weights every 10 epochs (optional)
if epoch > 0 and epoch % 10 == 0:
os.system('cp ' + latest + ' ' + weights + 'backup%g.pt' % epoch)
torch.save(checkpoint, weights + 'backup%g.pt' % epoch)
del checkpoint
# Calculate mAP
with torch.no_grad():