This commit is contained in:
Glenn Jocher 2019-07-08 18:00:19 +02:00
parent 59b1a1e89b
commit da9ec7d12f
1 changed files with 15 additions and 8 deletions

View File

@ -61,12 +61,13 @@ def train(
cutoff = -1 # backbone reaches to cutoff layer cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0 start_epoch = 0
best_fitness = 0.0 best_fitness = 0.0
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
if opt.resume or opt.transfer: # Load previously saved model if opt.resume or opt.transfer: # Load previously saved model
if opt.transfer: # Transfer learning if opt.transfer: # Transfer learning
nf = int(model.module_defs[model.yolo_layers[0] - 1]['filters']) # yolo layer size (i.e. 255)
chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device) chkpt = torch.load(weights + 'yolov3-spp.pt', map_location=device)
model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255}, model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != 255},
strict=False) strict=False)
for p in model.parameters(): for p in model.parameters():
p.requires_grad = True if p.shape[0] == nf else False p.requires_grad = True if p.shape[0] == nf else False
@ -74,10 +75,14 @@ def train(
chkpt = torch.load(latest, map_location=device) # load checkpoint chkpt = torch.load(latest, map_location=device) # load checkpoint
model.load_state_dict(chkpt['model']) model.load_state_dict(chkpt['model'])
start_epoch = chkpt['epoch'] + 1
if chkpt['optimizer'] is not None: if chkpt['optimizer'] is not None:
optimizer.load_state_dict(chkpt['optimizer']) optimizer.load_state_dict(chkpt['optimizer'])
best_fitness = chkpt['best_fitness'] best_fitness = chkpt['best_fitness']
with open('results.txt', 'w') as file:
file.write(chkpt['training_results']) # write results.txt
start_epoch = chkpt['epoch'] + 1
del chkpt del chkpt
else: # Initialize model with backbone (optional) else: # Initialize model with backbone (optional)
@ -246,9 +251,11 @@ def train(
# Save training results # Save training results
save = (not opt.nosave) or (epoch == epochs - 1) save = (not opt.nosave) or (epoch == epochs - 1)
if save: if save:
with open('results.txt', 'r') as file:
# Create checkpoint # Create checkpoint
chkpt = {'epoch': epoch, chkpt = {'epoch': epoch,
'best_fitness': best_fitness, 'best_fitness': best_fitness,
'training_results': file.read(),
'model': model.module.state_dict() if type( 'model': model.module.state_dict() if type(
model) is nn.parallel.DistributedDataParallel else model.state_dict(), model) is nn.parallel.DistributedDataParallel else model.state_dict(),
'optimizer': optimizer.state_dict()} 'optimizer': optimizer.state_dict()}