updates
This commit is contained in:
parent
2cd6805063
commit
d661fba8ae
2
train.py
2
train.py
|
@ -70,7 +70,7 @@ def train(
|
|||
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
||||
|
||||
# Set optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
|
||||
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
|
|
Loading…
Reference in New Issue