diff --git a/train.py b/train.py
index 514195c2..626bbcc0 100644
--- a/train.py
+++ b/train.py
@@ -68,7 +68,7 @@ def main(opt):
     # optimizer = torch.optim.SGD(model.parameters(), lr=.001, momentum=.9, weight_decay=5e-4, nesterov=True)
     # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
     optimizer = torch.optim.Adam(model.parameters())
-    #optimizer.load_state_dict(checkpoint['optimizer'])
+    optimizer.load_state_dict(checkpoint['optimizer'])
     start_epoch = checkpoint['epoch'] + 1
     best_loss = checkpoint['best_loss']
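
For context, a minimal sketch of the checkpoint round trip this change relies on: restoring the optimizer's state_dict keeps Adam's moment estimates across a resume instead of re-initializing them. The 'epoch', 'best_loss', and 'optimizer' keys come from the diff; the save_checkpoint helper, the 'model' key, and the checkpoint path are assumptions for illustration, not taken from train.py.

    import torch

    def save_checkpoint(model, optimizer, epoch, best_loss, path='checkpoint.pth'):
        # Persist optimizer state alongside the model so a resumed run
        # continues with the same Adam moment estimates.
        torch.save({
            'epoch': epoch,
            'best_loss': best_loss,
            'model': model.state_dict(),          # assumed key, not shown in the diff
            'optimizer': optimizer.state_dict(),
        }, path)

    def resume(model, path='checkpoint.pth'):
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint['model'])
        optimizer = torch.optim.Adam(model.parameters())
        optimizer.load_state_dict(checkpoint['optimizer'])  # the line this PR re-enables
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['best_loss']
        return optimizer, start_epoch, best_loss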