updates
This commit is contained in:
parent
a0b4d17f7e
commit
0e54731bb8
|
@ -253,21 +253,21 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
|
||||||
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
||||||
|
|
||||||
# Intersection area
|
# Intersection area
|
||||||
inter_area = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
||||||
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
||||||
|
|
||||||
# Union Area
|
# Union Area
|
||||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
|
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
|
||||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
|
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
|
||||||
union_area = (w1 * h1 + 1e-16) + w2 * h2 - inter_area
|
union = (w1 * h1 + 1e-16) + w2 * h2 - inter
|
||||||
|
|
||||||
iou = inter_area / union_area # iou
|
iou = inter / union # iou
|
||||||
if GIoU or DIoU or CIoU:
|
if GIoU or DIoU or CIoU:
|
||||||
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
|
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
|
||||||
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
||||||
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
|
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
|
||||||
c_area = cw * ch + 1e-16 # convex area
|
c_area = cw * ch + 1e-16 # convex area
|
||||||
return iou - (c_area - union_area) / c_area # GIoU
|
return iou - (c_area - union) / c_area # GIoU
|
||||||
if DIoU or CIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
if DIoU or CIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
|
||||||
# convex diagonal squared
|
# convex diagonal squared
|
||||||
c2 = cw ** 2 + ch ** 2 + 1e-16
|
c2 = cw ** 2 + ch ** 2 + 1e-16
|
||||||
|
@ -297,20 +297,18 @@ def box_iou(boxes1, boxes2):
|
||||||
IoU values for every element in boxes1 and boxes2
|
IoU values for every element in boxes1 and boxes2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def box_area(boxes):
|
def box_area(box):
|
||||||
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
# box = 4xn
|
||||||
|
return (box[2] - box[0]) * (box[3] - box[1])
|
||||||
|
|
||||||
area1 = box_area(boxes1)
|
area1 = box_area(boxes1.t())
|
||||||
area2 = box_area(boxes2)
|
area2 = box_area(boxes2.t())
|
||||||
|
|
||||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||||
|
|
||||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
inter = (rb - lt).clamp(min=0).prod(2) # [N,M]
|
||||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
||||||
|
|
||||||
iou = inter / (area1[:, None] + area2 - inter)
|
|
||||||
return iou
|
|
||||||
|
|
||||||
|
|
||||||
def wh_iou(wh1, wh2):
|
def wh_iou(wh1, wh2):
|
||||||
|
|
Loading…
Reference in New Issue