updates
This commit is contained in:
		
							parent
							
								
									a021f97110
								
							
						
					
					
						commit
						d2c5d7a5fd
					
				
							
								
								
									
										14
									
								
								models.py
								
								
								
								
							
							
						
						
									
										14
									
								
								models.py
								
								
								
								
							| 
						 | 
					@ -101,6 +101,9 @@ class YOLOLayer(nn.Module):
 | 
				
			||||||
        self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1))
 | 
					        self.anchor_h = self.scaled_anchors[:, 1:2].view((1, nA, 1, 1))
 | 
				
			||||||
        self.weights = class_weights()
 | 
					        self.weights = class_weights()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.batch_count = 0
 | 
				
			||||||
 | 
					        self.loss_means = torch.zeros(6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, p, targets=None, requestPrecision=False):
 | 
					    def forward(self, p, targets=None, requestPrecision=False):
 | 
				
			||||||
        FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
 | 
					        FT = torch.cuda.FloatTensor if p.is_cuda else torch.FloatTensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -139,6 +142,7 @@ class YOLOLayer(nn.Module):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Training
 | 
					        # Training
 | 
				
			||||||
        if targets is not None:
 | 
					        if targets is not None:
 | 
				
			||||||
 | 
					            self.batch_count += 1
 | 
				
			||||||
            MSELoss = nn.MSELoss()
 | 
					            MSELoss = nn.MSELoss()
 | 
				
			||||||
            BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
 | 
					            BCEWithLogitsLoss = nn.BCEWithLogitsLoss()
 | 
				
			||||||
            CrossEntropyLoss = nn.CrossEntropyLoss()
 | 
					            CrossEntropyLoss = nn.CrossEntropyLoss()
 | 
				
			||||||
| 
						 | 
					@ -181,7 +185,15 @@ class YOLOLayer(nn.Module):
 | 
				
			||||||
            # lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float())
 | 
					            # lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Sum loss components
 | 
					            # Sum loss components
 | 
				
			||||||
            loss = lx + ly + lw + lh + lconf + lcls
 | 
					            balance_losses_flag = True
 | 
				
			||||||
 | 
					            if balance_losses_flag:
 | 
				
			||||||
 | 
					                loss_vec = torch.FloatTensor([lx.data, ly.data, lw.data, lh.data, lconf.data, lcls.data])
 | 
				
			||||||
 | 
					                self.loss_means = self.loss_means * 0.99 + loss_vec * 0.01
 | 
				
			||||||
 | 
					                k = 1 / self.loss_means.clone()
 | 
				
			||||||
 | 
					                k /= k.sum()
 | 
				
			||||||
 | 
					                loss = (lx * k[0] + ly * k[1] + lw * k[2] + lh * k[3] + lconf * k[4] + lcls * k[5]) * loss_vec.sum()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                loss = lx + ly + lw + lh + lconf + lcls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Sum False Positives from unassigned anchors
 | 
					            # Sum False Positives from unassigned anchors
 | 
				
			||||||
            i = torch.sigmoid(pred_conf[~mask]) > 0.5
 | 
					            i = torch.sigmoid(pred_conf[~mask]) > 0.5
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue