updates
This commit is contained in:
		
							parent
							
								
									a021f97110
								
							
						
					
					
						commit
						d2c5d7a5fd
					
				
							
								
								
									
										14
									
								
								models.py
								
								
								
								
							
							
						
						
									
										14
									
								
								models.py
								
								
								
								
							|  | @ -101,6 +101,9 @@ class YOLOLayer(nn.Module): | |||
|         self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1)) | ||||
|         self.weights = class_weights() | ||||
| 
 | ||||
|         self.batch_count = 0 | ||||
|         self.loss_means = torch.zeros(6) | ||||
| 
 | ||||
|     def forward(self, p, targets=None, requestPrecision=False): | ||||
|         FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor | ||||
| 
 | ||||
|  | @ -139,6 +142,7 @@ class YOLOLayer(nn.Module): | |||
| 
 | ||||
|         # Training | ||||
|         if targets is not None: | ||||
|             self.batch_count += 1 | ||||
|             MSELoss = nn.MSELoss() | ||||
|             BCEWithLogitsLoss = nn.BCEWithLogitsLoss() | ||||
|             CrossEntropyLoss = nn.CrossEntropyLoss() | ||||
|  | @ -181,7 +185,15 @@ class YOLOLayer(nn.Module): | |||
|             # lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) | ||||
| 
 | ||||
|             # Sum loss components | ||||
|             loss = lx + ly + lw + lh + lconf + lcls | ||||
|             balance_losses_flag = True | ||||
|             if balance_losses_flag: | ||||
|                 loss_vec = torch.FloatTensor([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data]) | ||||
|                 self.loss_means = self.loss_means * 0.99 + loss_vec * 0.01 | ||||
|                 k = 1 / self.loss_means.clone() | ||||
|                 k /= k.sum() | ||||
|                 loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) * loss_vec.sum() | ||||
|             else: | ||||
|                 loss = lx + ly + lw + lh + lconf + lcls | ||||
| 
 | ||||
|             # Sum False Positives from unassigned anchors | ||||
|             i = torch.sigmoid(pred_conf[~mask]) > 0.5 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue