updates
This commit is contained in:
parent
be934ba5a5
commit
1cd907c59b
15
train.py
15
train.py
|
@ -32,9 +32,8 @@ def train(
|
||||||
else:
|
else:
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
os.makedirs(weights, exist_ok=True)
|
latest = os.path.join(weights, 'latest.pt')
|
||||||
latest_weights_file = os.path.join(weights, 'latest.pt')
|
best = os.path.join(weights, 'best.pt')
|
||||||
best_weights_file = os.path.join(weights, 'best.pt')
|
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
data_cfg = parse_data_cfg(data_cfg)
|
data_cfg = parse_data_cfg(data_cfg)
|
||||||
|
@ -49,7 +48,7 @@ def train(
|
||||||
|
|
||||||
lr0 = 0.001
|
lr0 = 0.001
|
||||||
if resume:
|
if resume:
|
||||||
checkpoint = torch.load(latest_weights_file, map_location='cpu')
|
checkpoint = torch.load(latest, map_location='cpu')
|
||||||
|
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
|
@ -185,18 +184,18 @@ def train(
|
||||||
'best_loss': best_loss,
|
'best_loss': best_loss,
|
||||||
'model': model.state_dict(),
|
'model': model.state_dict(),
|
||||||
'optimizer': optimizer.state_dict()}
|
'optimizer': optimizer.state_dict()}
|
||||||
torch.save(checkpoint, latest_weights_file)
|
torch.save(checkpoint, latest)
|
||||||
|
|
||||||
# Save best checkpoint
|
# Save best checkpoint
|
||||||
if best_loss == loss_per_target:
|
if best_loss == loss_per_target:
|
||||||
os.system('cp ' + latest_weights_file + ' ' + best_weights_file)
|
os.system('cp ' + latest + ' ' + best)
|
||||||
|
|
||||||
# Save backup weights every 5 epochs
|
# Save backup weights every 5 epochs
|
||||||
if (epoch > 0) & (epoch % 5 == 0):
|
if (epoch > 0) & (epoch % 5 == 0):
|
||||||
os.system('cp ' + latest_weights_file + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
|
||||||
|
|
||||||
# Calculate mAP
|
# Calculate mAP
|
||||||
mAP, R, P = test.test(cfg, data_cfg, weights=latest_weights_file, batch_size=batch_size, img_size=img_size)
|
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
with open('results.txt', 'a') as file:
|
with open('results.txt', 'a') as file:
|
||||||
|
|
Loading…
Reference in New Issue