This commit is contained in:
Glenn Jocher 2019-03-21 11:48:50 +02:00
parent 2cd6805063
commit d661fba8ae
1 changed files with 1 additions and 1 deletions

View File

@ -70,7 +70,7 @@ def train(
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
# Set optimizer # Set optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9) optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
model = nn.DataParallel(model) model = nn.DataParallel(model)