updates
This commit is contained in:
		
							parent
							
								
									6345a1d218
								
							
						
					
					
						commit
						8610026e2c
					
				
							
								
								
									
										17
									
								
								train.py
								
								
								
								
							
							
						
						
									
										17
									
								
								train.py
								
								
								
								
							|  | @ -372,6 +372,15 @@ def train(): | ||||||
|     return results |     return results | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def prebias(): | ||||||
|  |     # trains output bias layers for 1 epoch and creates new backbone | ||||||
|  |     if opt.prebias: | ||||||
|  |         train()  # transfer-learn yolo biases for 1 epoch | ||||||
|  |         create_backbone(last)  # saved results as backbone.pt | ||||||
|  |         opt.weights = wdir + 'backbone.pt'  # assign backbone | ||||||
|  |         opt.prebias = False  # disable prebias | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     parser = argparse.ArgumentParser() |     parser = argparse.ArgumentParser() | ||||||
|     parser.add_argument('--epochs', type=int, default=273)  # 500200 batches at bs 16, 117263 images = 273 epochs |     parser.add_argument('--epochs', type=int, default=273)  # 500200 batches at bs 16, 117263 images = 273 epochs | ||||||
|  | @ -403,12 +412,6 @@ if __name__ == '__main__': | ||||||
|     device = torch_utils.select_device(opt.device, apex=mixed_precision) |     device = torch_utils.select_device(opt.device, apex=mixed_precision) | ||||||
| 
 | 
 | ||||||
|     tb_writer = None |     tb_writer = None | ||||||
|     if opt.prebias: |  | ||||||
|         train()  # transfer-learn yolo biases for 1 epoch |  | ||||||
|         create_backbone(last)  # saved results as backbone.pt |  | ||||||
|         opt.weights = wdir + 'backbone.pt'  # assign backbone |  | ||||||
|         opt.prebias = False  # disable prebias |  | ||||||
| 
 |  | ||||||
|     if not opt.evolve:  # Train normally |     if not opt.evolve:  # Train normally | ||||||
|         try: |         try: | ||||||
|             # Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/ |             # Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/ | ||||||
|  | @ -418,6 +421,7 @@ if __name__ == '__main__': | ||||||
|         except: |         except: | ||||||
|             pass |             pass | ||||||
| 
 | 
 | ||||||
|  |         prebias()  # optional | ||||||
|         train()  # train normally |         train()  # train normally | ||||||
| 
 | 
 | ||||||
|     else:  # Evolve hyperparameters (optional) |     else:  # Evolve hyperparameters (optional) | ||||||
|  | @ -455,6 +459,7 @@ if __name__ == '__main__': | ||||||
|                 hyp[k] = np.clip(hyp[k], v[0], v[1]) |                 hyp[k] = np.clip(hyp[k], v[0], v[1]) | ||||||
| 
 | 
 | ||||||
|             # Train mutation |             # Train mutation | ||||||
|  |             prebias() | ||||||
|             results = train() |             results = train() | ||||||
| 
 | 
 | ||||||
|             # Write mutation results |             # Write mutation results | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue