This commit is contained in:
Glenn Jocher 2019-11-24 20:08:24 -10:00
parent 7773651e8e
commit 9b55bbf9e2
1 changed files with 6 additions and 10 deletions

View File

@ -267,20 +267,16 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
c_area = cw * ch + 1e-16 # convex area c_area = cw * ch + 1e-16 # convex area
return iou - (c_area - union_area) / c_area # GIoU return iou - (c_area - union_area) / c_area # GIoU
if DIoU or CIoU: # Distance IoU https://arxiv.org/abs/1911.08287v1 if DIoU or CIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
c2 = cw ** 2 + ch ** 2 + 1e-16 # convex diagonal squared # convex diagonal squared
# b1_xc, b1_yc = (b1_x1 + b1_x2) / 2, (b1_y1 + b1_y2) / 2 c2 = cw ** 2 + ch ** 2 + 1e-16
# b2_xc, b2_yc = (b2_x1 + b2_x2) / 2, (b2_y1 + b2_y2) / 2 # centerpoint distance squared
# rho2 = (b2_xc - b1_xc) ** 2 + (b2_yc - b1_yc) ** 2 # centerpoint distance squared
rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4 rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
if DIoU: if DIoU:
return iou - rho2 / c2 # DIoU return iou - rho2 / c2 # DIoU
elif CIoU: elif CIoU:
atan = torch.atan(w2 / h2) - torch.atan(w1 / h1) v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
v = (4 / math.pi ** 2) * torch.pow(atan, 2) return iou - (rho2 / c2 + v ** 2 / (1 - iou + v)) # CIoU
alpha = v / (1 - iou + v)
# ar = - (8 / (math.pi ** 2)) * atan * (w1 * h1)
return iou - (rho2 / c2 + alpha * v) # CIoU
return iou return iou