This commit is contained in:
Glenn Jocher 2019-02-22 15:05:03 +01:00
parent 0f3018124f
commit 12e605165e
1 changed files with 2 additions and 3 deletions

View File

@ -50,10 +50,9 @@ def train(
# model = nn.DataParallel(model)
model.to(device).train()
# # Transfer learning (train only YOLO layers)
# Transfer learning (train only YOLO layers)
# for i, (name, p) in enumerate(model.named_parameters()):
# if p.shape[0] != 650: # not YOLO layer
# p.requires_grad = False
# p.requires_grad = True if (p.shape[0] == 255) else False
# Set optimizer
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)