updates
This commit is contained in:
parent
5b7325bd06
commit
0aff657e19
4
train.py
4
train.py
|
@ -53,7 +53,7 @@ def train(
|
||||||
|
|
||||||
if resume: # Load previously saved model
|
if resume: # Load previously saved model
|
||||||
if transfer: # Transfer learning
|
if transfer: # Transfer learning
|
||||||
chkpt = torch.load(weights + 'yolov3.pt', map_location=device)
|
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
|
||||||
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
||||||
strict=False)
|
strict=False)
|
||||||
for p in model.parameters():
|
for p in model.parameters():
|
||||||
|
@ -99,9 +99,9 @@ def train(
|
||||||
sampler=sampler)
|
sampler=sampler)
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
nB = len(dataloader)
|
|
||||||
t = time.time()
|
t = time.time()
|
||||||
model_info(model)
|
model_info(model)
|
||||||
|
nB = len(dataloader)
|
||||||
n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches
|
n_burnin = min(round(nB / 5 + 1), 1000) # burn-in batches
|
||||||
for epoch in range(start_epoch, epochs):
|
for epoch in range(start_epoch, epochs):
|
||||||
model.train()
|
model.train()
|
||||||
|
|
Loading…
Reference in New Issue