updates
This commit is contained in:
		
							parent
							
								
									bca423ee43
								
							
						
					
					
						commit
						81b4a7833f
					
				|  | @ -219,7 +219,7 @@ def compute_ap(recall, precision): | |||
|     return ap | ||||
| 
 | ||||
| 
 | ||||
| def bbox_iou(box1, box2, x1y1x2y2=True): | ||||
| def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False): | ||||
|     # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4 | ||||
|     box2 = box2.t() | ||||
| 
 | ||||
|  | @ -243,7 +243,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True): | |||
|     union_area = ((b1_x2 - b1_x1) * (b1_y2 - b1_y1) + 1e-16) + \ | ||||
|                  (b2_x2 - b2_x1) * (b2_y2 - b2_y1) - inter_area | ||||
| 
 | ||||
|     return inter_area / union_area  # iou | ||||
|     iou = inter_area / union_area  # iou | ||||
|     if GIoU:  # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf | ||||
|         c_x1, c_x2 = torch.min(b1_x1, b2_x1), torch.max(b1_x2, b2_x2) | ||||
|         c_y1, c_y2 = torch.min(b1_y1, b2_y1), torch.max(b1_y2, b2_y2) | ||||
|         c_area = (c_x2 - c_x1) * (c_y2 - c_y1)  # convex area | ||||
|         return iou - (c_area - union_area) / c_area  # GIoU | ||||
| 
 | ||||
|     return iou | ||||
| 
 | ||||
| 
 | ||||
| def wh_iou(box1, box2): | ||||
|  | @ -265,8 +272,8 @@ def wh_iou(box1, box2): | |||
| 
 | ||||
| def compute_loss(p, targets, model):  # predictions, targets, model | ||||
|     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor | ||||
|     lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0]) | ||||
|     txy, twh, tcls, indices = build_targets(model, targets) | ||||
|     lxy, lwh, lcls, lconf, lgiou = ft([0]), ft([0]), ft([0]), ft([0]), ft([0]) | ||||
|     txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets) | ||||
| 
 | ||||
|     # Define criteria | ||||
|     MSE = nn.MSELoss() | ||||
|  | @ -287,6 +294,11 @@ def compute_loss(p, targets, model):  # predictions, targets, model | |||
|             tconf[b, a, gj, gi] = 1  # conf | ||||
|             # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4])  # wh power loss (uncomment) | ||||
| 
 | ||||
|             # Build GIoU boxes | ||||
|             pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1)  # predicted box | ||||
|             giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) | ||||
| 
 | ||||
|             # lxy += (k * h['giou']) * (1.0 - giou).mean()  # giou loss | ||||
|             lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss | ||||
|             lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss | ||||
|             lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # class_conf loss | ||||
|  | @ -306,7 +318,7 @@ def build_targets(model, targets): | |||
|         model = model.module | ||||
| 
 | ||||
|     nt = len(targets) | ||||
|     txy, twh, tcls, indices = [], [], [], [] | ||||
|     txy, twh, tcls, tbox, indices, anchor_vec = [], [], [], [], [], [] | ||||
|     for i in model.yolo_layers: | ||||
|         layer = model.module_list[i][0] | ||||
| 
 | ||||
|  | @ -330,7 +342,12 @@ def build_targets(model, targets): | |||
|         indices.append((b, a, gj, gi)) | ||||
| 
 | ||||
|         # XY coordinates | ||||
|         txy.append(gxy - gxy.floor()) | ||||
|         gxy -= gxy.floor() | ||||
|         txy.append(gxy) | ||||
| 
 | ||||
|         # GIoU | ||||
|         tbox.append(torch.cat((gxy, gwh), 1))  # xywh (grids) | ||||
|         anchor_vec.append(layer.anchor_vec[a]) | ||||
| 
 | ||||
|         # Width and height | ||||
|         twh.append(torch.log(gwh / layer.anchor_vec[a]))  # wh yolo method | ||||
|  | @ -341,7 +358,7 @@ def build_targets(model, targets): | |||
|         if c.shape[0]: | ||||
|             assert c.max() <= layer.nc, 'Target classes exceed model classes' | ||||
| 
 | ||||
|     return txy, twh, tcls, indices | ||||
|     return txy, twh, tcls, tbox, indices, anchor_vec | ||||
| 
 | ||||
| 
 | ||||
| def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue