updates
This commit is contained in:
parent
c36f1e990b
commit
d79a54a4be
8
train.py
8
train.py
|
@ -53,12 +53,12 @@ def train(
|
||||||
yl = get_yolo_layers(model) # yolo layers
|
yl = get_yolo_layers(model) # yolo layers
|
||||||
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
nf = int(model.module_defs[yl[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||||
|
|
||||||
if resume: # Load previously saved PyTorch model
|
if resume: # Load previously saved model
|
||||||
if transfer: # Transfer learning
|
if transfer: # Transfer learning
|
||||||
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
|
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
|
||||||
model.load_state_dict(
|
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != nf},
|
||||||
{k: v for k, v in chkpt['model'].items() if (int(k.split('.')[1]) + 1) not in yl}, strict=False)
|
strict=False)
|
||||||
for (name, p) in model.named_parameters():
|
for p in model.parameters():
|
||||||
p.requires_grad = True if p.shape[0] == nf else False
|
p.requires_grad = True if p.shape[0] == nf else False
|
||||||
|
|
||||||
else: # resume from latest.pt
|
else: # resume from latest.pt
|
||||||
|
|
Loading…
Reference in New Issue