diff --git a/train.py b/train.py index aaef96c8..bd056686 100644 --- a/train.py +++ b/train.py @@ -108,11 +108,11 @@ def train(): chkpt = torch.load(weights, map_location=device) # load model - if opt.transfer: - chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} - model.load_state_dict(chkpt['model'], strict=False) - else: - model.load_state_dict(chkpt['model']) + # if opt.transfer: + chkpt['model'] = {k: v for k, v in chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()} + model.load_state_dict(chkpt['model'], strict=False) + # else: + # model.load_state_dict(chkpt['model']) # load optimizer if chkpt['optimizer'] is not None: