updates
This commit is contained in:
parent
be934ba5a5
commit
1cd907c59b
15
train.py
15
train.py
|
@ -32,9 +32,8 @@ def train(
|
|||
else:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
os.makedirs(weights, exist_ok=True)
|
||||
latest_weights_file = os.path.join(weights, 'latest.pt')
|
||||
best_weights_file = os.path.join(weights, 'best.pt')
|
||||
latest = os.path.join(weights, 'latest.pt')
|
||||
best = os.path.join(weights, 'best.pt')
|
||||
|
||||
# Configure run
|
||||
data_cfg = parse_data_cfg(data_cfg)
|
||||
|
@ -49,7 +48,7 @@ def train(
|
|||
|
||||
lr0 = 0.001
|
||||
if resume:
|
||||
checkpoint = torch.load(latest_weights_file, map_location='cpu')
|
||||
checkpoint = torch.load(latest, map_location='cpu')
|
||||
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
if torch.cuda.device_count() > 1:
|
||||
|
@ -185,18 +184,18 @@ def train(
|
|||
'best_loss': best_loss,
|
||||
'model': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict()}
|
||||
torch.save(checkpoint, latest_weights_file)
|
||||
torch.save(checkpoint, latest)
|
||||
|
||||
# Save best checkpoint
|
||||
if best_loss == loss_per_target:
|
||||
os.system('cp ' + latest_weights_file + ' ' + best_weights_file)
|
||||
os.system('cp ' + latest + ' ' + best)
|
||||
|
||||
# Save backup weights every 5 epochs
|
||||
if (epoch > 0) & (epoch % 5 == 0):
|
||||
os.system('cp ' + latest_weights_file + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
||||
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
||||
|
||||
# Calculate mAP
|
||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest_weights_file, batch_size=batch_size, img_size=img_size)
|
||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
|
||||
|
||||
# Write epoch results
|
||||
with open('results.txt', 'a') as file:
|
||||
|
|
Loading…
Reference in New Issue