Update train.py
solve the Multi-GPU --resume Error #138 https://github.com/ultralytics/yolov3/issues/138
This commit is contained in:
		
							parent
							
								
									2cd6805063
								
							
						
					
					
						commit
						35396adc9c
					
				
							
								
								
									
										4
									
								
								train.py
								
								
								
								
							
							
						
						
									
										4
									
								
								train.py
								
								
								
								
							|  | @ -32,7 +32,7 @@ def train( | ||||||
|     train_path = parse_data_cfg(data_cfg)['train'] |     train_path = parse_data_cfg(data_cfg)['train'] | ||||||
| 
 | 
 | ||||||
|     # Initialize model |     # Initialize model | ||||||
|     model = Darknet(cfg, img_size) |     model = Darknet(cfg, img_size).to(device) | ||||||
| 
 | 
 | ||||||
|     # Get dataloader |     # Get dataloader | ||||||
|     dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True) |     dataloader = LoadImagesAndLabels(train_path, batch_size, img_size, augment=True) | ||||||
|  | @ -43,7 +43,7 @@ def train( | ||||||
|     start_epoch = 0 |     start_epoch = 0 | ||||||
|     best_loss = float('inf') |     best_loss = float('inf') | ||||||
|     if resume: |     if resume: | ||||||
|         checkpoint = torch.load(latest, map_location='cpu') |         checkpoint = torch.load(latest, map_location=device) | ||||||
| 
 | 
 | ||||||
|         # Load weights to resume from |         # Load weights to resume from | ||||||
|         model.load_state_dict(checkpoint['model']) |         model.load_state_dict(checkpoint['model']) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue