updates
This commit is contained in:
		
							parent
							
								
									c36f1e990b
								
							
						
					
					
						commit
						d79a54a4be
					
				
							
								
								
									
										8
									
								
								train.py
								
								
								
								
							
							
						
						
									
										8
									
								
								train.py
								
								
								
								
							|  | @ -53,12 +53,12 @@ def train( | |||
|     yl = get_yolo_layers(model)  # yolo layers | ||||
|     nf = int(model.module_defs[yl[0] - 1]['filters'])  # yolo layer size (i.e. 255) | ||||
| 
 | ||||
|     if resume:  # Load previously saved PyTorch model | ||||
|     if resume:  # Load previously saved model | ||||
|         if transfer:  # Transfer learning | ||||
|             chkpt = torch.load(weights + 'yolov3.pt', map_location=device) | ||||
|             model.load_state_dict( | ||||
|                 {k: v for k, v in chkpt['model'].items() if (int(k.split('.')[1]) + 1) not in yl}, strict=False) | ||||
|             for (name, p) in model.named_parameters(): | ||||
|             model.load_state_dict({k: v for k, v in chkpt['model'].items() if v.numel() > 1 and v.shape[0] != nf}, | ||||
|                                   strict=False) | ||||
|             for p in model.parameters(): | ||||
|                 p.requires_grad = True if p.shape[0] == nf else False | ||||
| 
 | ||||
|         else:  # resume from latest.pt | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue