update loss components
This commit is contained in:
		
							parent
							
								
									bd3f617129
								
							
						
					
					
						commit
						cf9b4cfa52
					
				
							
								
								
									
										36
									
								
								models.py
								
								
								
								
							
							
						
						
									
										36
									
								
								models.py
								
								
								
								
							|  | @ -137,10 +137,9 @@ class YOLOLayer(nn.Module): | ||||||
| 
 | 
 | ||||||
|         # Training |         # Training | ||||||
|         if targets is not None: |         if targets is not None: | ||||||
|             BCEWithLogitsLoss1 = nn.BCEWithLogitsLoss(size_average=False) |             MSELoss = nn.MSELoss() | ||||||
|             BCEWithLogitsLoss2 = nn.BCEWithLogitsLoss(size_average=True) |             BCEWithLogitsLoss = nn.BCEWithLogitsLoss() | ||||||
|             MSELoss = nn.MSELoss(size_average=False)  # version 0.4.0 |             # CrossEntropyLoss = nn.CrossEntropyLoss() | ||||||
|             CrossEntropyLoss = nn.CrossEntropyLoss() |  | ||||||
| 
 | 
 | ||||||
|             if requestPrecision: |             if requestPrecision: | ||||||
|                 gx = self.grid_x[:, :, :nG, :nG] |                 gx = self.grid_x[:, :, :nG, :nG] | ||||||
|  | @ -161,29 +160,32 @@ class YOLOLayer(nn.Module): | ||||||
|             nT = sum([len(x) for x in targets])  # number of targets |             nT = sum([len(x) for x in targets])  # number of targets | ||||||
|             nM = mask.sum().float()  # number of anchors (assigned to targets) |             nM = mask.sum().float()  # number of anchors (assigned to targets) | ||||||
|             nB = len(targets)  # batch size |             nB = len(targets)  # batch size | ||||||
|  |             k = nM / nB | ||||||
|             if nM > 0: |             if nM > 0: | ||||||
|                     lx = (5 / nB) * MSELoss(x[mask], tx[mask]) |                 lx = k * MSELoss(x[mask], tx[mask]) | ||||||
|                     ly = (5 / nB) * MSELoss(y[mask], ty[mask]) |                 ly = k * MSELoss(y[mask], ty[mask]) | ||||||
|                     lw = (5 / nB) * MSELoss(w[mask], tw[mask]) |                 lw = k * MSELoss(w[mask], tw[mask]) | ||||||
|                     lh = (5 / nB) * MSELoss(h[mask], th[mask]) |                 lh = k * MSELoss(h[mask], th[mask]) | ||||||
|                     lconf = (1 / nB) * BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float()) |                 lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float()) | ||||||
| 
 | 
 | ||||||
|                     lcls = (1 * nM / nB) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) |                 # lcls = k * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) | ||||||
|                     # lcls = (1 * nM / nB) * BCEWithLogitsLoss2(pred_cls[mask], tcls.float()) |                 lcls = k * BCEWithLogitsLoss(pred_cls[mask], tcls.float()) | ||||||
|             else: |             else: | ||||||
|                 lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) |                 lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) | ||||||
| 
 | 
 | ||||||
|                 lconf += (0.5 * nM / nB) * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float()) |             # Add confidence loss for background anchors (noobj) | ||||||
|  |             lconf += k * BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) | ||||||
| 
 | 
 | ||||||
|  |             # Sum loss components | ||||||
|             loss = lx + ly + lw + lh + lconf + lcls |             loss = lx + ly + lw + lh + lconf + lcls | ||||||
| 
 | 
 | ||||||
|             # Sum False Positives from unnasigned anchors |             # Sum False Positives from unassigned anchors | ||||||
|             i = torch.sigmoid(pred_conf[~mask]) > 0.99 |             i = torch.sigmoid(pred_conf[~mask]) > 0.9 | ||||||
|             FPe = torch.zeros(self.nC) |  | ||||||
|             if i.sum() > 0: |             if i.sum() > 0: | ||||||
|                 FP_classes = torch.argmax(pred_cls[~mask][i], 1) |                 FP_classes = torch.argmax(pred_cls[~mask][i], 1) | ||||||
|                 for c in FP_classes: |                 FPe = torch.bincount(FP_classes, minlength=self.nC).float().cpu()  # extra FPs | ||||||
|                     FPe[c] += 1 |             else: | ||||||
|  |                 FPe = torch.zeros(self.nC) | ||||||
| 
 | 
 | ||||||
|             return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \ |             return loss, loss.item(), lx.item(), ly.item(), lw.item(), lh.item(), lconf.item(), lcls.item(), \ | ||||||
|                    nT, TP, FP, FPe, FN, TC |                    nT, TP, FP, FPe, FN, TC | ||||||
|  |  | ||||||
|  | @ -11,7 +11,7 @@ gsutil cp gs://ultralytics/fresh9_5_e201.pt yolov3/checkpoints | ||||||
| python3 detect.py | python3 detect.py | ||||||
| 
 | 
 | ||||||
| # Test | # Test | ||||||
| python3 test.py -img_size 416 -weights_path checkpoints/yolov3.weights | python3 test.py -img_size 416 -weights_path checkpoints/latest.pt -conf_thresh 0.5 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # Download and Test | # Download and Test | ||||||
|  |  | ||||||
|  | @ -282,9 +282,9 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG | ||||||
|             pconf = torch.sigmoid(pred_conf[b, a, gj, gi]).cpu() |             pconf = torch.sigmoid(pred_conf[b, a, gj, gi]).cpu() | ||||||
|             iou_pred = bbox_iou(tb, pred_boxes[b, a, gj, gi].cpu()) |             iou_pred = bbox_iou(tb, pred_boxes[b, a, gj, gi].cpu()) | ||||||
| 
 | 
 | ||||||
|             TP[b, i] = (pconf > 0.99) & (iou_pred > 0.5) & (pcls == tc) |             TP[b, i] = (pconf > 0.9) & (iou_pred > 0.5) & (pcls == tc) | ||||||
|             FP[b, i] = (pconf > 0.99) & (TP[b, i] == 0)  # coordinates or class are wrong |             FP[b, i] = (pconf > 0.9) & (TP[b, i] == 0)  # coordinates or class are wrong | ||||||
|             FN[b, i] = pconf <= 0.99  # confidence score is too low (set to zero) |             FN[b, i] = pconf <= 0.9  # confidence score is too low (set to zero) | ||||||
| 
 | 
 | ||||||
|     return tx, ty, tw, th, tconf, tcls, TP, FP, FN, TC |     return tx, ty, tw, th, tconf, tcls, TP, FP, FN, TC | ||||||
| 
 | 
 | ||||||
|  | @ -429,8 +429,8 @@ def plotResults(): | ||||||
|     import matplotlib.pyplot as plt |     import matplotlib.pyplot as plt | ||||||
|     plt.figure(figsize=(16, 8)) |     plt.figure(figsize=(16, 8)) | ||||||
|     s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall'] |     s = ['X', 'Y', 'Width', 'Height', 'Objectness', 'Classification', 'Total Loss', 'Precision', 'Recall'] | ||||||
|     for f in ('/Users/glennjocher/Downloads/results_CE.txt', |     for f in ('results.txt', | ||||||
|               '/Users/glennjocher/Downloads/results_BCE.txt'): |               ): | ||||||
|         results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 9, 10]).T |         results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 7, 8, 9, 10]).T | ||||||
|         for i in range(9): |         for i in range(9): | ||||||
|             plt.subplot(2, 5, i + 1) |             plt.subplot(2, 5, i + 1) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue