This commit is contained in:
Glenn Jocher 2019-10-05 12:45:10 +02:00
parent 6345a1d218
commit 8610026e2c
1 changed files with 11 additions and 6 deletions

View File

@ -372,6 +372,15 @@ def train():
return results return results
def prebias():
# trains output bias layers for 1 epoch and creates new backbone
if opt.prebias:
train() # transfer-learn yolo biases for 1 epoch
create_backbone(last) # saved results as backbone.pt
opt.weights = wdir + 'backbone.pt' # assign backbone
opt.prebias = False # disable prebias
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=273) # 500200 batches at bs 16, 117263 images = 273 epochs parser.add_argument('--epochs', type=int, default=273) # 500200 batches at bs 16, 117263 images = 273 epochs
@ -403,12 +412,6 @@ if __name__ == '__main__':
device = torch_utils.select_device(opt.device, apex=mixed_precision) device = torch_utils.select_device(opt.device, apex=mixed_precision)
tb_writer = None tb_writer = None
if opt.prebias:
train() # transfer-learn yolo biases for 1 epoch
create_backbone(last) # saved results as backbone.pt
opt.weights = wdir + 'backbone.pt' # assign backbone
opt.prebias = False # disable prebias
if not opt.evolve: # Train normally if not opt.evolve: # Train normally
try: try:
# Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/ # Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/
@ -418,6 +421,7 @@ if __name__ == '__main__':
except: except:
pass pass
prebias() # optional
train() # train normally train() # train normally
else: # Evolve hyperparameters (optional) else: # Evolve hyperparameters (optional)
@ -455,6 +459,7 @@ if __name__ == '__main__':
hyp[k] = np.clip(hyp[k], v[0], v[1]) hyp[k] = np.clip(hyp[k], v[0], v[1])
# Train mutation # Train mutation
prebias()
results = train() results = train()
# Write mutation results # Write mutation results