diff --git a/utils/utils.py b/utils/utils.py index cc314102..18ccb514 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -313,15 +313,12 @@ def box_iou(boxes1, boxes2): return iou -def wh_iou(box1, box2): - # Returns the IoU of wh1 to wh2. wh1 is 2, wh2 is 2xn - w1, h1 = box1[0], box1[1] - w2, h2 = box2[0], box2[1] - - # Intersection area - inter = torch.min(w1, w2) * torch.min(h1, h2) - - return inter / (w1 * h1 + w2 * h2 - inter) # iou = inter / (area1 + area2 - inter) +def wh_iou(wh1, wh2): + # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2 + wh1 = wh1[:, None] # [N,1,2] + wh2 = wh2[None] # [1,M,2] + inter = torch.min(wh1, wh2).prod(2) # [N,M] + return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) class FocalLoss(nn.Module): @@ -445,9 +442,8 @@ def build_targets(model, targets): # iou of targets-anchors t, a = targets, [] gwh = t[:, 4:6] * ng - gwht = gwh.t() if nt: - iou = torch.stack([wh_iou(x, gwht) for x in anchor_vec], 0) + iou = wh_iou(anchor_vec, gwh) if use_all_anchors: na = len(anchor_vec) # number of anchors