diff --git a/train.py b/train.py
index 92a879a1..dff2dca0 100644
--- a/train.py
+++ b/train.py
@@ -63,7 +63,8 @@ def main(opt):
 
     # Set optimizer
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
-    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()))
+    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
+                                lr=1e-3, momentum=.9, weight_decay=5e-4)
 
     start_epoch = checkpoint['epoch'] + 1
     if checkpoint['optimizer'] is not None:
@@ -85,7 +86,7 @@ def main(opt):
 
     # Set optimizer
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)
-    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=.9, weight_decay=5e-4, nesterov=True)
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=.9, weight_decay=5e-4)
 
     # Set scheduler
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)
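
For context, a minimal sketch of the optimizer setup this patch converges on. The first hunk matters because `torch.optim.SGD` has historically required `lr` to be passed explicitly, so the original call in the resume path would fail at construction time; the second hunk drops `nesterov=True` so both the resume and fresh-start paths build the same optimizer. The `build_optimizer` helper and the exact shape of `checkpoint` are assumptions inferred from the hunk context, not code from this repo:

```python
import torch

def build_optimizer(model, checkpoint=None):
    # Hypothetical helper mirroring the two hunks above; hyperparameters
    # (lr=1e-3, momentum=.9, weight_decay=5e-4) are taken from the diff.
    # Only optimize parameters that are not frozen, matching the
    # filter(...) call in the resume path.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.SGD(trainable, lr=1e-3, momentum=.9,
                                weight_decay=5e-4)

    start_epoch = 0
    if checkpoint is not None:
        # Assumed checkpoint layout ({'epoch': ..., 'optimizer': ...}),
        # based on the context lines in the first hunk.
        start_epoch = checkpoint['epoch'] + 1
        if checkpoint.get('optimizer') is not None:
            # Restore momentum buffers so resumed training matches the
            # trajectory of the interrupted run.
            optimizer.load_state_dict(checkpoint['optimizer'])
    return optimizer, start_epoch
```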