This commit is contained in:
Glenn Jocher 2019-03-04 17:38:38 +01:00
parent 54b62f5302
commit e5dc942fee
1 changed files with 4 additions and 4 deletions

View File

@ -48,8 +48,8 @@ def train(
# Load weights to resume from # Load weights to resume from
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
if torch.cuda.device_count() > 1: # if torch.cuda.device_count() > 1:
model = nn.DataParallel(model) # model = nn.DataParallel(model)
model.to(device).train() model.to(device).train()
# Transfer learning (train only YOLO layers) # Transfer learning (train only YOLO layers)
@ -75,8 +75,8 @@ def train(
load_darknet_weights(model, weights + 'yolov3-tiny.conv.15') load_darknet_weights(model, weights + 'yolov3-tiny.conv.15')
cutoff = 15 cutoff = 15
if torch.cuda.device_count() > 1: # if torch.cuda.device_count() > 1:
model = nn.DataParallel(model) # model = nn.DataParallel(model)
model.to(device).train() model.to(device).train()
# Set optimizer # Set optimizer