diff --git a/train.py b/train.py
index a1a443dd..92a879a1 100644
--- a/train.py
+++ b/train.py
@@ -63,8 +63,7 @@ def main(opt):
 
     # Set optimizer
     # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))
-    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, momentum=.9,
-                                weight_decay=5e-4, nesterov=True)
+    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
     start_epoch = checkpoint['epoch'] + 1
 
     if checkpoint['optimizer'] is not None: