This commit is contained in:
Glenn Jocher 2019-03-21 12:11:08 +02:00
parent aecf840701
commit 2856af5036
1 changed files with 3 additions and 3 deletions

View File

@ -64,9 +64,9 @@ def train(
if torch.cuda.device_count() > 1: if torch.cuda.device_count() > 1:
model = nn.DataParallel(model) model = nn.DataParallel(model)
# # Transfer learning (train only YOLO layers) # Transfer learning (train only YOLO layers)
for i, (name, p) in enumerate(model.named_parameters()): # for i, (name, p) in enumerate(model.named_parameters()):
p.requires_grad = True if (p.shape[0] == 255) else False # p.requires_grad = True if (p.shape[0] == 255) else False
# Set scheduler # Set scheduler
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1) # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[54, 61], gamma=0.1)