From f8e5338f0a1e73e4a53bb0adba92eee38d8c0179 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 30 Jun 2020 16:19:56 -0700 Subject: [PATCH] --resume epochs update --- train.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/train.py b/train.py index 705b5198..5402f707 100644 --- a/train.py +++ b/train.py @@ -116,29 +116,35 @@ def train(hyp): attempt_download(weights) if weights.endswith('.pt'): # pytorch format # possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc. - chkpt = torch.load(weights, map_location=device) + ckpt = torch.load(weights, map_location=device) # load model try: - chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} - model.load_state_dict(chkpt['model'], strict=False) + ckpt['model'] = {k: v for k, v in ckpt['model'].items() if model.state_dict()[k].numel() == v.numel()} + model.load_state_dict(ckpt['model'], strict=False) except KeyError as e: s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \ "See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights) raise KeyError(s) from e # load optimizer - if chkpt['optimizer'] is not None: - optimizer.load_state_dict(chkpt['optimizer']) - best_fitness = chkpt['best_fitness'] + if ckpt['optimizer'] is not None: + optimizer.load_state_dict(ckpt['optimizer']) + best_fitness = ckpt['best_fitness'] # load results - if chkpt.get('training_results') is not None: + if ckpt.get('training_results') is not None: with open(results_file, 'w') as file: - file.write(chkpt['training_results']) # write results.txt + file.write(ckpt['training_results']) # write results.txt - start_epoch = chkpt['epoch'] + 1 - del chkpt + # epochs + start_epoch = ckpt['epoch'] + 1 + if epochs < start_epoch: + print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' % + (opt.weights, ckpt['epoch'], epochs)) + epochs += ckpt['epoch'] # finetune additional epochs + + del ckpt elif len(weights) > 0: # darknet format # possible weights are '*.weights', 'yolov3-tiny.conv.15', 'darknet53.conv.74' etc. @@ -349,17 +355,17 @@ def train(hyp): save = (not opt.nosave) or (final_epoch and not opt.evolve) if save: with open(results_file, 'r') as f: # create checkpoint - chkpt = {'epoch': epoch, + ckpt = {'epoch': epoch, 'best_fitness': best_fitness, 'training_results': f.read(), 'model': ema.ema.module.state_dict() if hasattr(model, 'module') else ema.ema.state_dict(), 'optimizer': None if final_epoch else optimizer.state_dict()} # Save last, best and delete - torch.save(chkpt, last) + torch.save(ckpt, last) if (best_fitness == fi) and not final_epoch: - torch.save(chkpt, best) - del chkpt + torch.save(ckpt, best) + del ckpt # end epoch ---------------------------------------------------------------------------------------------------- # end training