updates
This commit is contained in:
		
							parent
							
								
									db515a4535
								
							
						
					
					
						commit
						f18f288990
					
				
							
								
								
									
										14
									
								
								models.py
								
								
								
								
							
							
						
						
									
										14
									
								
								models.py
								
								
								
								
							|  | @ -194,12 +194,12 @@ class YOLOLayer(nn.Module): | |||
|                 loss = lx + ly + lw + lh + lconf + lcls | ||||
| 
 | ||||
|             # Sum False Positives from unassigned anchors | ||||
|             i = torch.sigmoid(pred_conf[~mask]) > 0.5 | ||||
|             if i.sum() > 0: | ||||
|                 FP_classes = torch.argmax(pred_cls[~mask][i], 1) | ||||
|                 FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu()  # extra FPs | ||||
|             else: | ||||
|                 FPe = torch.zeros(self.nC) | ||||
|             FPe = torch.zeros(self.nC) | ||||
|             if requestPrecision: | ||||
|                 i = torch.sigmoid(pred_conf[~mask]) > 0.5 | ||||
|                 if i.sum() > 0: | ||||
|                     FP_classes = torch.argmax(pred_cls[~mask][i], 1) | ||||
|                     FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu()  # extra FPs | ||||
| 
 | ||||
|             return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \ | ||||
|                    nT, TP, FP, FPe, FN, TC | ||||
|  | @ -254,7 +254,7 @@ class Darknet(nn.Module): | |||
|                 output.append(x) | ||||
|             layer_outputs.append(x) | ||||
| 
 | ||||
|         if is_training: | ||||
|         if is_training and requestPrecision: | ||||
|             self.losses['nT'] /= 3 | ||||
|             self.losses['TC'] /= 3  # target category | ||||
|             metrics = torch.zeros(3, len(self.losses['FPe']))  # TP, FP, FN | ||||
|  |  | |||
|  | @ -214,7 +214,8 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG | |||
|         if nTb == 0: | ||||
|             continue | ||||
|         t = target[b] | ||||
|         FN[b, :nTb] = 1 | ||||
|         if requestPrecision: | ||||
|             FN[b, :nTb] = 1 | ||||
| 
 | ||||
|         # Convert to position relative to box | ||||
|         TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue