This commit is contained in:
glenn-jocher 2019-07-02 18:21:28 +02:00
parent ccf757b3ea
commit a8cf64af31
1 changed files with 7 additions and 7 deletions

View File

@ -77,7 +77,7 @@ def train(
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0 start_epoch = 0
best_map = 0. best_fitness = 0.0
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255) nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
if opt.resume or opt.transfer: # Load previously saved model if opt.resume or opt.transfer: # Load previously saved model
if opt.transfer: # Transfer learning if opt.transfer: # Transfer learning
@ -94,7 +94,7 @@ def train(
start_epoch = chkpt['epoch'] + 1 start_epoch = chkpt['epoch'] + 1
if chkpt['optimizer'] is not None: if chkpt['optimizer'] is not None:
optimizer.load_state_dict(chkpt['optimizer']) optimizer.load_state_dict(chkpt['optimizer'])
best_loss = chkpt['best_loss'] best_fitness = chkpt['best_fitness']
del chkpt del chkpt
else: # Initialize model with backbone (optional) else: # Initialize model with backbone (optional)
@ -257,16 +257,16 @@ def train(
file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss file.write(s + '%11.3g' * 5 % results + '\n') # P, R, mAP, F1, test_loss
# Update best map # Update best map
test_map = results[2] fitness = results[2]
if test_map > best_map: if fitness > best_fitness:
best_map = test_map best_fitness = fitness
# Save training results # Save training results
save = (not opt.nosave) or (epoch == epochs - 1) save = (not opt.nosave) or (epoch == epochs - 1)
if save: if save:
# Create checkpoint # Create checkpoint
chkpt = {'epoch': epoch, chkpt = {'epoch': epoch,
'best_map': best_map, 'best_fitness': best_fitness,
'model': model.module.state_dict() if type( 'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(), model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': optimizer.state_dict()}
@ -275,7 +275,7 @@ def train(
torch.save(chkpt, latest) torch.save(chkpt, latest)
# Save best checkpoint # Save best checkpoint
if best_loss == test_loss: if best_fitness == fitness:
torch.save(chkpt, best) torch.save(chkpt, best)
# Save backup every 10 epochs (optional) # Save backup every 10 epochs (optional)