updates
This commit is contained in:
parent
d79a54a4be
commit
5170cd36b0
4
train.py
4
train.py
|
@ -40,8 +40,6 @@ def train(
|
||||||
|
|
||||||
# Initialize model
|
# Initialize model
|
||||||
model = Darknet(cfg, img_size).to(device)
|
model = Darknet(cfg, img_size).to(device)
|
||||||
# for m in model.modules():
|
|
||||||
# weights_init_normal(m) # set weight distributions
|
|
||||||
|
|
||||||
# Optimizer
|
# Optimizer
|
||||||
lr0 = 0.001 # initial learning rate
|
lr0 = 0.001 # initial learning rate
|
||||||
|
@ -56,7 +54,7 @@ def train(
|
||||||
if resume: # Load previously saved model
|
if resume: # Load previously saved model
|
||||||
if transfer: # Transfer learning
|
if transfer: # Transfer learning
|
||||||
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
|
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
|
||||||
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != nf},
|
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
||||||
strict=False)
|
strict=False)
|
||||||
for p in model.parameters():
|
for p in model.parameters():
|
||||||
p.requires_grad = True if p.shape[0] == nf else False
|
p.requires_grad = True if p.shape[0] == nf else False
|
||||||
|
|
Loading…
Reference in New Issue