updates
This commit is contained in:
parent
2cd6805063
commit
d661fba8ae
2
train.py
2
train.py
|
@ -70,7 +70,7 @@ def train(
|
||||||
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
cutoff = load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
||||||
|
|
||||||
# Set optimizer
|
# Set optimizer
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr0, momentum=.9)
|
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
if torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
model = nn.DataParallel(model)
|
||||||
|
|
Loading…
Reference in New Issue