updates
This commit is contained in:
parent
59b1a1e89b
commit
da9ec7d12f
23
train.py
23
train.py
|
@ -61,12 +61,13 @@ def train(
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_fitness = 0.0
|
best_fitness = 0.0
|
||||||
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
|
||||||
if opt.resume or opt.transfer: # Load previously saved model
|
if opt.resume or opt.transfer: # Load previously saved model
|
||||||
if opt.transfer: # Transfer learning
|
if opt.transfer: # Transfer learning
|
||||||
|
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
|
||||||
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
|
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
|
||||||
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
|
||||||
strict=False)
|
strict=False)
|
||||||
|
|
||||||
for p in model.parameters():
|
for p in model.parameters():
|
||||||
p.requires_grad = True if p.shape[0] == nf else False
|
p.requires_grad = True if p.shape[0] == nf else False
|
||||||
|
|
||||||
|
@ -74,10 +75,14 @@ def train(
|
||||||
chkpt = torch.load(latest, map_location=device) # load checkpoint
|
chkpt = torch.load(latest, map_location=device) # load checkpoint
|
||||||
model.load_state_dict(chkpt['model'])
|
model.load_state_dict(chkpt['model'])
|
||||||
|
|
||||||
start_epoch = chkpt['epoch'] + 1
|
|
||||||
if chkpt['optimizer'] is not None:
|
if chkpt['optimizer'] is not None:
|
||||||
optimizer.load_state_dict(chkpt['optimizer'])
|
optimizer.load_state_dict(chkpt['optimizer'])
|
||||||
best_fitness = chkpt['best_fitness']
|
best_fitness = chkpt['best_fitness']
|
||||||
|
|
||||||
|
with open('results.txt', 'w') as file:
|
||||||
|
file.write(chkpt['training_results']) # write results.txt
|
||||||
|
|
||||||
|
start_epoch = chkpt['epoch'] + 1
|
||||||
del chkpt
|
del chkpt
|
||||||
|
|
||||||
else: # Initialize model with backbone (optional)
|
else: # Initialize model with backbone (optional)
|
||||||
|
@ -246,12 +251,14 @@ def train(
|
||||||
# Save training results
|
# Save training results
|
||||||
save = (not opt.nosave) or (epoch == epochs - 1)
|
save = (not opt.nosave) or (epoch == epochs - 1)
|
||||||
if save:
|
if save:
|
||||||
# Create checkpoint
|
with open('results.txt', 'r') as file:
|
||||||
chkpt = {'epoch': epoch,
|
# Create checkpoint
|
||||||
'best_fitness': best_fitness,
|
chkpt = {'epoch': epoch,
|
||||||
'model': model.module.state_dict() if type(
|
'best_fitness': best_fitness,
|
||||||
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
'training_results': file.read(),
|
||||||
'optimizer': optimizer.state_dict()}
|
'model': model.module.state_dict() if type(
|
||||||
|
model) is nn.parallel.DistributedDataParallel else model.state_dict(),
|
||||||
|
'optimizer': optimizer.state_dict()}
|
||||||
|
|
||||||
# Save latest checkpoint
|
# Save latest checkpoint
|
||||||
torch.save(chkpt, latest)
|
torch.save(chkpt, latest)
|
||||||
|
|
Loading…
Reference in New Issue