Non-output layer freeze in train.py (#1333)
Freeze layers that aren't of type YOLOLayer and that aren't the conv layers preceeding them
This commit is contained in:
		
							parent
							
								
									ca7794ed05
								
							
						
					
					
						commit
						a97f350461
					
				
							
								
								
									
										11
									
								
								train.py
								
								
								
								
							
							
						
						
									
										11
									
								
								train.py
								
								
								
								
							|  | @ -143,6 +143,16 @@ def train(hyp): | |||
|     elif len(weights) > 0:  # darknet format | ||||
|         # possible weights are '*.weights', 'yolov3-tiny.conv.15',  'darknet53.conv.74' etc. | ||||
|         load_darknet_weights(model, weights) | ||||
|      | ||||
|     if opt.freeze_layers:                                                                                                                                                             | ||||
|         output_layer_indices = [idx - 1 for idx, module in enumerate(model.module_list) \                                                                                             | ||||
|                              if isinstance(module, YOLOLayer)]                                                                                                                       | ||||
|         freeze_layer_indices = [x for x in range(len(model.module_list)) if\                                                                                                          | ||||
|                                 (x not in output_layer_indices) and \                                                                                                                 | ||||
|                                 (x - 1 not in output_layer_indices)]                                                                                                                  | ||||
|         for idx in freeze_layer_indices:                                                                                                                                              | ||||
|             for parameter in model.module_list[idx].parameters():                                                                                                                     | ||||
|                 parameter.requires_grad_(False)                                                                                                                                       | ||||
| 
 | ||||
|     # Mixed precision training https://github.com/NVIDIA/apex | ||||
|     if mixed_precision: | ||||
|  | @ -394,6 +404,7 @@ if __name__ == '__main__': | |||
|     parser.add_argument('--device', default='', help='device id (i.e. 0 or 0,1 or cpu)') | ||||
|     parser.add_argument('--adam', action='store_true', help='use adam optimizer') | ||||
|     parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') | ||||
|     parser.add_argument('--freeze-layers', action='store_true', help='Freeze non-output layers')   | ||||
|     opt = parser.parse_args() | ||||
|     opt.weights = last if opt.resume else opt.weights | ||||
|     check_git_status() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue