This commit is contained in:
Glenn Jocher 2019-02-22 16:15:20 +01:00
parent 12e605165e
commit ac22a717f1
1 changed files with 4 additions and 5 deletions

View File

@ -40,8 +40,10 @@ def train(
lr0 = 0.001
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
best_loss = float('inf')
if resume:
checkpoint = torch.load(latest, map_location='cpu')
checkpoint = torch.load('weights/yolov3.pt', map_location='cpu')
# Load weights to resume from
model.load_state_dict(checkpoint['model'])
@ -52,7 +54,7 @@ def train(
# Transfer learning (train only YOLO layers)
# for i, (name, p) in enumerate(model.named_parameters()):
# p.requires_grad = True if (p.shape[0] == 255) else False
# p.requires_grad = True if (p.shape[0] == 255) else False
# Set optimizer
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, model.parameters()), lr=lr0, momentum=.9)
@ -65,9 +67,6 @@ def train(
del checkpoint # current, saved
else:
start_epoch = 0
best_loss = float('inf')
# Initialize model with backbone (optional)
if cfg.endswith('yolov3.cfg'):
load_darknet_weights(model, weights + 'darknet53.conv.74')