updates
This commit is contained in:
parent
54b62f5302
commit
e5dc942fee
8
train.py
8
train.py
|
@ -48,8 +48,8 @@ def train(
|
|||
# Load weights to resume from
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
# if torch.cuda.device_count() > 1:
|
||||
# model = nn.DataParallel(model)
|
||||
model.to(device).train()
|
||||
|
||||
# Transfer learning (train only YOLO layers)
|
||||
|
@ -75,8 +75,8 @@ def train(
|
|||
load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
|
||||
cutoff = 15
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
model = nn.DataParallel(model)
|
||||
# if torch.cuda.device_count() > 1:
|
||||
# model = nn.DataParallel(model)
|
||||
model.to(device).train()
|
||||
|
||||
# Set optimizer
|
||||
|
|
Loading…
Reference in New Issue