updates
This commit is contained in:
parent
6345a1d218
commit
8610026e2c
17
train.py
17
train.py
|
@ -372,6 +372,15 @@ def train():
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def prebias():
|
||||||
|
# trains output bias layers for 1 epoch and creates new backbone
|
||||||
|
if opt.prebias:
|
||||||
|
train() # transfer-learn yolo biases for 1 epoch
|
||||||
|
create_backbone(last) # saved results as backbone.pt
|
||||||
|
opt.weights = wdir + 'backbone.pt' # assign backbone
|
||||||
|
opt.prebias = False # disable prebias
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--epochs', type=int, default=273) # 500200 batches at bs 16, 117263 images = 273 epochs
|
parser.add_argument('--epochs', type=int, default=273) # 500200 batches at bs 16, 117263 images = 273 epochs
|
||||||
|
@ -403,12 +412,6 @@ if __name__ == '__main__':
|
||||||
device = torch_utils.select_device(opt.device, apex=mixed_precision)
|
device = torch_utils.select_device(opt.device, apex=mixed_precision)
|
||||||
|
|
||||||
tb_writer = None
|
tb_writer = None
|
||||||
if opt.prebias:
|
|
||||||
train() # transfer-learn yolo biases for 1 epoch
|
|
||||||
create_backbone(last) # saved results as backbone.pt
|
|
||||||
opt.weights = wdir + 'backbone.pt' # assign backbone
|
|
||||||
opt.prebias = False # disable prebias
|
|
||||||
|
|
||||||
if not opt.evolve: # Train normally
|
if not opt.evolve: # Train normally
|
||||||
try:
|
try:
|
||||||
# Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/
|
# Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/
|
||||||
|
@ -418,6 +421,7 @@ if __name__ == '__main__':
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
prebias() # optional
|
||||||
train() # train normally
|
train() # train normally
|
||||||
|
|
||||||
else: # Evolve hyperparameters (optional)
|
else: # Evolve hyperparameters (optional)
|
||||||
|
@ -455,6 +459,7 @@ if __name__ == '__main__':
|
||||||
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
||||||
|
|
||||||
# Train mutation
|
# Train mutation
|
||||||
|
prebias()
|
||||||
results = train()
|
results = train()
|
||||||
|
|
||||||
# Write mutation results
|
# Write mutation results
|
||||||
|
|
Loading…
Reference in New Issue