updates
This commit is contained in:
parent
54b62f5302
commit
e5dc942fee
8
train.py
8
train.py
|
@ -48,8 +48,8 @@ def train(
|
||||||
# Load weights to resume from
|
# Load weights to resume from
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
# if torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
# model = nn.DataParallel(model)
|
||||||
model.to(device).train()
|
model.to(device).train()
|
||||||
|
|
||||||
# Transfer learning (train only YOLO layers)
|
# Transfer learning (train only YOLO layers)
|
||||||
|
@ -75,8 +75,8 @@ def train(
|
||||||
load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
||||||
cutoff = 15
|
cutoff = 15
|
||||||
|
|
||||||
if torch.cuda.device_count() > 1:
|
# if torch.cuda.device_count() > 1:
|
||||||
model = nn.DataParallel(model)
|
# model = nn.DataParallel(model)
|
||||||
model.to(device).train()
|
model.to(device).train()
|
||||||
|
|
||||||
# Set optimizer
|
# Set optimizer
|
||||||
|
|
Loading…
Reference in New Issue