From 12e605165e5c1a640a4ad3e3b84c015e742b76c0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 22 Feb 2019 15:05:03 +0100 Subject: [PATCH] updates --- train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index f9414341..7fcdcefd 100644 --- a/train.py +++ b/train.py @@ -50,10 +50,9 @@ def train( # model = nn.DataParallel(model) model.to(device).train() - # # Transfer learning (train only YOLO layers) + # Transfer learning (train only YOLO layers) # for i, (name, p) in enumerate(model.named_parameters()): - # if p.shape[0] != 650: # not YOLO layer - # p.requires_grad = False + # p.requires_grad = True if (p.shape[0] == 255) else False # Set optimizer optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)