--resume epochs update

This commit is contained in:
Glenn Jocher 2020-06-30 16:19:56 -07:00
parent 7f953b2106
commit f8e5338f0a
1 changed files with 20 additions and 14 deletions

View File

@ -116,29 +116,35 @@ def train(hyp):
attempt_download(weights)
if weights.endswith('.pt'): # pytorch format
# possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
chkpt = torch.load(weights, map_location=device)
ckpt = torch.load(weights, map_location=device)
# load model
try:
chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
model.load_state_dict(chkpt['model'], strict=False)
ckpt['model'] = {k: v for k, v in ckpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
model.load_state_dict(ckpt['model'], strict=False)
except KeyError as e:
s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \
"See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights)
raise KeyError(s) from e
# load optimizer
if chkpt['optimizer'] is not None:
optimizer.load_state_dict(chkpt['optimizer'])
best_fitness = chkpt['best_fitness']
if ckpt['optimizer'] is not None:
optimizer.load_state_dict(ckpt['optimizer'])
best_fitness = ckpt['best_fitness']
# load results
if chkpt.get('training_results') is not None:
if ckpt.get('training_results') is not None:
with open(results_file, 'w') as file:
file.write(chkpt['training_results']) # write results.txt
file.write(ckpt['training_results']) # write results.txt
start_epoch = chkpt['epoch'] + 1
del chkpt
# epochs
start_epoch = ckpt['epoch'] + 1
if epochs < start_epoch:
print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
(opt.weights, ckpt['epoch'], epochs))
epochs += ckpt['epoch'] # finetune additional epochs
del ckpt
elif len(weights) > 0: # darknet format
# possible weights are '*.weights', 'yolov3-tiny.conv.15', 'darknet53.conv.74' etc.
@ -349,17 +355,17 @@ def train(hyp):
save = (not opt.nosave) or (final_epoch and not opt.evolve)
if save:
with open(results_file, 'r') as f: # create checkpoint
chkpt = {'epoch': epoch,
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': f.read(),
'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(),
'optimizer': None if final_epoch else optimizer.state_dict()}
# Save last, best and delete
torch.save(chkpt, last)
torch.save(ckpt, last)
if (best_fitness == fi) and not final_epoch:
torch.save(chkpt, best)
del chkpt
torch.save(ckpt, best)
del ckpt
# end epoch ----------------------------------------------------------------------------------------------------
# end training