This commit is contained in:
Glenn Jocher 2019-02-09 19:29:19 +01:00
parent be934ba5a5
commit 1cd907c59b
1 changed files with 7 additions and 8 deletions

View File

@ -32,9 +32,8 @@ def train(
else:
torch.backends.cudnn.benchmark = True
os.makedirs(weights, exist_ok=True)
latest_weights_file = os.path.join(weights, 'latest.pt')
best_weights_file = os.path.join(weights, 'best.pt')
latest = os.path.join(weights, 'latest.pt')
best = os.path.join(weights, 'best.pt')
# Configure run
data_cfg = parse_data_cfg(data_cfg)
@ -49,7 +48,7 @@ def train(
lr0 = 0.001
if resume:
checkpoint = torch.load(latest_weights_file, map_location='cpu')
checkpoint = torch.load(latest, map_location='cpu')
model.load_state_dict(checkpoint['model'])
if torch.cuda.device_count() > 1:
@ -185,18 +184,18 @@ def train(
'best_loss': best_loss,
'model': model.state_dict(),
'optimizer': optimizer.state_dict()}
torch.save(checkpoint, latest_weights_file)
torch.save(checkpoint, latest)
# Save best checkpoint
if best_loss == loss_per_target:
os.system('cp ' + latest_weights_file + ' ' + best_weights_file)
os.system('cp ' + latest + ' ' + best)
# Save backup weights every 5 epochs
if (epoch > 0) & (epoch % 5 == 0):
os.system('cp ' + latest_weights_file + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
# Calculate mAP
mAP, R, P = test.test(cfg, data_cfg, weights=latest_weights_file, batch_size=batch_size, img_size=img_size)
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
# Write epoch results
with open('results.txt', 'a') as file: