updates
This commit is contained in:
		
							parent
							
								
									8f1becd55c
								
							
						
					
					
						commit
						d25190e15b
					
				
							
								
								
									
										6
									
								
								train.py
								
								
								
								
							
							
						
						
									
										6
									
								
								train.py
								
								
								
								
							|  | @ -64,7 +64,9 @@ def train( | |||
|         torch.backends.cudnn.benchmark = True  # unsuitable for multiscale | ||||
| 
 | ||||
|     # Configure run | ||||
|     train_path = parse_data_cfg(data_cfg)['train'] | ||||
|     data_cfg = parse_data_cfg(data_cfg) | ||||
|     train_path = data_cfg['train'] | ||||
|     nc = data_cfg['classes']  # number of classes | ||||
| 
 | ||||
|     # Initialize model | ||||
|     model = Darknet(cfg, img_size).to(device) | ||||
|  | @ -145,7 +147,7 @@ def train( | |||
|     # Start training | ||||
|     t, t0 = time.time(), time.time() | ||||
|     model.hyp = hyp  # attach hyperparameters to model | ||||
|     model.class_weights = labels_to_class_weights(dataset.labels).to(device)  # attach class weights | ||||
|     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights | ||||
|     model_info(model) | ||||
|     nb = len(dataloader) | ||||
|     results = (0, 0, 0, 0, 0)  # P, R, mAP, F1, test_loss | ||||
|  |  | |||
|  | @ -49,11 +49,11 @@ def model_info(model): | |||
|     print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g)) | ||||
| 
 | ||||
| 
 | ||||
| def labels_to_class_weights(labels): | ||||
| def labels_to_class_weights(labels, nc=80): | ||||
|     # Get class weights (inverse frequency) from training labels | ||||
|     labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO | ||||
|     classes = labels[:, 0].astype(np.int) | ||||
|     weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6)  # number of targets per class | ||||
|     weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6)  # number of targets per class | ||||
|     weights /= weights.sum() | ||||
|     return torch.Tensor(weights) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue