diff --git a/train.py b/train.py index ad13f4c0..9175d4cf 100644 --- a/train.py +++ b/train.py @@ -104,7 +104,9 @@ def train(): # load model if opt.transfer: chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} - model.load_state_dict(chkpt['model'], strict=False) + model.load_state_dict(chkpt['model'], strict=False) + else: + model.load_state_dict(chkpt['model']) # load optimizer if chkpt['optimizer'] is not None: