train.py remove hardcoded weights/ path for weights.
If I want to store my weights in 'weights2' path: python train.py --weights-path weights2 Default is the same: weights
This commit is contained in:
		
							parent
							
								
									9c0c1f23ab
								
							
						
					
					
						commit
						868a116750
					
				
							
								
								
									
										42
									
								
								train.py
								
								
								
								
							
							
						
						
									
										42
									
								
								train.py
								
								
								
								
							|  | @ -11,6 +11,11 @@ from utils import torch_utils | ||||||
| # Import test.py to get mAP after each epoch | # Import test.py to get mAP after each epoch | ||||||
| import test | import test | ||||||
| 
 | 
 | ||||||
|  | DARKNET_WEIGHTS_FILENAME = 'darknet53.conv.74' | ||||||
|  | DARKNET_WEIGHTS_URL = 'https://pjreddie.com/media/files/{}'.format( | ||||||
|  |     DARKNET_WEIGHTS_FILENAME | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| def train( | def train( | ||||||
|     net_config_path, |     net_config_path, | ||||||
|  | @ -19,6 +24,7 @@ def train( | ||||||
|     resume=False, |     resume=False, | ||||||
|     epochs=100, |     epochs=100, | ||||||
|     batch_size=16, |     batch_size=16, | ||||||
|  |     weights_path='weights', | ||||||
|     report=False, |     report=False, | ||||||
|     multi_scale=False, |     multi_scale=False, | ||||||
|     freeze_backbone=True, |     freeze_backbone=True, | ||||||
|  | @ -31,12 +37,14 @@ def train( | ||||||
|     if not multi_scale: |     if not multi_scale: | ||||||
|         torch.backends.cudnn.benchmark = True |         torch.backends.cudnn.benchmark = True | ||||||
| 
 | 
 | ||||||
|     os.makedirs('weights', exist_ok=True) |     os.makedirs(weights_path, exist_ok=True) | ||||||
|  |     latest_weights_file = os.path.join(weights_path, 'latest.pt') | ||||||
|  |     best_weights_file = os.path.join(weights_path, 'best.pt') | ||||||
| 
 | 
 | ||||||
|     # Configure run |     # Configure run | ||||||
|     data_config = parse_data_config(data_config_path) |     data_config = parse_data_config(data_config_path) | ||||||
|     num_classes = int(data_config['classes']) |     num_classes = int(data_config['classes']) | ||||||
|     train_path = '../coco/trainvalno5k.txt' |     train_path = data_config['train'] | ||||||
| 
 | 
 | ||||||
|     # Initialize model |     # Initialize model | ||||||
|     model = Darknet(net_config_path, img_size) |     model = Darknet(net_config_path, img_size) | ||||||
|  | @ -50,7 +58,7 @@ def train( | ||||||
| 
 | 
 | ||||||
|     lr0 = 0.001 |     lr0 = 0.001 | ||||||
|     if resume: |     if resume: | ||||||
|         checkpoint = torch.load('weights/latest.pt', map_location='cpu') |         checkpoint = torch.load(latest_weights_file, map_location='cpu') | ||||||
| 
 | 
 | ||||||
|         model.load_state_dict(checkpoint['model']) |         model.load_state_dict(checkpoint['model']) | ||||||
|         if torch.cuda.device_count() > 1: |         if torch.cuda.device_count() > 1: | ||||||
|  | @ -79,9 +87,13 @@ def train( | ||||||
|         best_loss = float('inf') |         best_loss = float('inf') | ||||||
| 
 | 
 | ||||||
|         # Initialize model with darknet53 weights (optional) |         # Initialize model with darknet53 weights (optional) | ||||||
|         if not os.path.isfile('weights/darknet53.conv.74'): |         def_weight_file = os.path.join(weights_path, DARKNET_WEIGHTS_FILENAME) | ||||||
|             os.system('wget https://pjreddie.com/media/files/darknet53.conv.74 -P weights') |         if not os.path.isfile(def_weight_file): | ||||||
|         load_weights(model, 'weights/darknet53.conv.74') |             os.system('wget {} -P {}'.format( | ||||||
|  |                 DARKNET_WEIGHTS_URL, | ||||||
|  |                 weights_path)) | ||||||
|  |         assert os.path.isfile(def_weight_file) | ||||||
|  |         load_weights(model, def_weight_file) | ||||||
| 
 | 
 | ||||||
|         if torch.cuda.device_count() > 1: |         if torch.cuda.device_count() > 1: | ||||||
|             raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21') |             raise Exception('Multi-GPU not currently supported: https://github.com/ultralytics/yolov3/issues/21') | ||||||
|  | @ -187,21 +199,29 @@ def train( | ||||||
|                       'best_loss': best_loss, |                       'best_loss': best_loss, | ||||||
|                       'model': model.state_dict(), |                       'model': model.state_dict(), | ||||||
|                       'optimizer': optimizer.state_dict()} |                       'optimizer': optimizer.state_dict()} | ||||||
|         torch.save(checkpoint, 'weights/latest.pt') |         torch.save(checkpoint, latest_weights_file) | ||||||
| 
 | 
 | ||||||
|         # Save best checkpoint |         # Save best checkpoint | ||||||
|         if best_loss == loss_per_target: |         if best_loss == loss_per_target: | ||||||
|             os.system('cp weights/latest.pt weights/best.pt') |             os.system('cp {} {}'.format( | ||||||
|  |                 latest_weights_file, | ||||||
|  |                 best_weights_file, | ||||||
|  |             )) | ||||||
| 
 | 
 | ||||||
|         # Save backup weights every 5 epochs |         # Save backup weights every 5 epochs | ||||||
|         if (epoch > 0) & (epoch % 5 == 0): |         if (epoch > 0) & (epoch % 5 == 0): | ||||||
|             os.system('cp weights/latest.pt weights/backup' + str(epoch) + '.pt') |             backup_file_name = 'backup{}.pt'.format(epoch) | ||||||
|  |             backup_file_path = os.path.join(weights_path, backup_file_name) | ||||||
|  |             os.system('cp {} {}'.format( | ||||||
|  |                 latest_weights_file, | ||||||
|  |                 backup_file_path, | ||||||
|  |             )) | ||||||
| 
 | 
 | ||||||
|         # Calculate mAP |         # Calculate mAP | ||||||
|         mAP, R, P = test.test( |         mAP, R, P = test.test( | ||||||
|             net_config_path, |             net_config_path, | ||||||
|             data_config_path, |             data_config_path, | ||||||
|             'weights/latest.pt', |             latest_weights_file, | ||||||
|             batch_size=batch_size, |             batch_size=batch_size, | ||||||
|             img_size=img_size, |             img_size=img_size, | ||||||
|         ) |         ) | ||||||
|  | @ -224,6 +244,7 @@ if __name__ == '__main__': | ||||||
|     parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') |     parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') | ||||||
|     parser.add_argument('--multi-scale', default=False, help='random image sizes per batch 320 - 608') |     parser.add_argument('--multi-scale', default=False, help='random image sizes per batch 320 - 608') | ||||||
|     parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels') |     parser.add_argument('--img-size', type=int, default=32 * 13, help='pixels') | ||||||
|  |     parser.add_argument('--weights-path', type=str, default='weights', help='path to store weights') | ||||||
|     parser.add_argument('--resume', action='store_true', help='resume training flag') |     parser.add_argument('--resume', action='store_true', help='resume training flag') | ||||||
|     parser.add_argument('--report', action='store_true', help='report TP, FP, FN, P and R per batch (slower)') |     parser.add_argument('--report', action='store_true', help='report TP, FP, FN, P and R per batch (slower)') | ||||||
|     parser.add_argument('--freeze-darknet53', default=False, help='freeze darknet53.conv.74 layers for first epoch') |     parser.add_argument('--freeze-darknet53', default=False, help='freeze darknet53.conv.74 layers for first epoch') | ||||||
|  | @ -241,6 +262,7 @@ if __name__ == '__main__': | ||||||
|         resume=opt.resume, |         resume=opt.resume, | ||||||
|         epochs=opt.epochs, |         epochs=opt.epochs, | ||||||
|         batch_size=opt.batch_size, |         batch_size=opt.batch_size, | ||||||
|  |         weights_path=opt.weights_path, | ||||||
|         report=opt.report, |         report=opt.report, | ||||||
|         multi_scale=opt.multi_scale, |         multi_scale=opt.multi_scale, | ||||||
|         freeze_backbone=opt.freeze_darknet53, |         freeze_backbone=opt.freeze_darknet53, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue