diff --git a/train.py b/train.py index 9699cd2f..ec592567 100644 --- a/train.py +++ b/train.py @@ -107,11 +107,14 @@ def train(): chkpt = torch.load(weights, map_location=device) # load model - # if opt.transfer: - chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} - model.load_state_dict(chkpt['model'], strict=False) - # else: - # model.load_state_dict(chkpt['model']) + try: + chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} + model.load_state_dict(chkpt['model'], strict=False) + # model.load_state_dict(chkpt['model']) + except KeyError as e: + s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \ + "See https://github.com/ultralytics/yolov3/issues/657" % (opt.weights, opt.cfg, opt.weights) + raise KeyError(s) from e # load optimizer if chkpt['optimizer'] is not None: