updates
This commit is contained in:
parent
b027c66048
commit
4aff400777
|
@ -11,6 +11,7 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
|
||||
from . import torch_utils # , google_utils
|
||||
|
||||
|
@ -234,7 +235,7 @@ def compute_ap(recall, precision):
|
|||
return ap
|
||||
|
||||
|
||||
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False):
|
||||
def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
|
||||
# Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
|
||||
box2 = box2.t()
|
||||
|
||||
|
@ -255,15 +256,31 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False):
|
|||
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
||||
|
||||
# Union Area
|
||||
union_area = ((b1_x2 - b1_x1) * (b1_y2 - b1_y1) + 1e-16) + \
|
||||
(b2_x2 - b2_x1) * (b2_y2 - b2_y1) - inter_area
|
||||
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
|
||||
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
|
||||
union_area = (w1 * h1 + 1e-16) + w2 * h2 - inter_area
|
||||
|
||||
iou = inter_area / union_area # iou
|
||||
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
|
||||
c_x1, c_x2 = torch.min(b1_x1, b2_x1), torch.max(b1_x2, b2_x2)
|
||||
c_y1, c_y2 = torch.min(b1_y1, b2_y1), torch.max(b1_y2, b2_y2)
|
||||
c_area = (c_x2 - c_x1) * (c_y2 - c_y1) + 1e-16 # convex area
|
||||
return iou - (c_area - union_area) / c_area # GIoU
|
||||
if GIoU or DIoU or CIoU:
|
||||
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
|
||||
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
|
||||
if GIoU: # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
|
||||
c_area = cw * ch + 1e-16 # convex area
|
||||
return iou - (c_area - union_area) / c_area # GIoU
|
||||
if DIoU or CIoU: # Distance IoU https://arxiv.org/abs/1911.08287v1
|
||||
c2 = cw ** 2 + ch ** 2 + 1e-16 # convex diagonal squared
|
||||
# b1_xc, b1_yc = (b1_x1 + b1_x2) / 2, (b1_y1 + b1_y2) / 2
|
||||
# b2_xc, b2_yc = (b2_x1 + b2_x2) / 2, (b2_y1 + b2_y2) / 2
|
||||
# rho2 = (b2_xc - b1_xc) ** 2 + (b2_yc - b1_yc) ** 2 # centerpoint distance squared
|
||||
rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
|
||||
if DIoU:
|
||||
return iou - rho2 / c2 # DIoU
|
||||
elif CIoU:
|
||||
atan = torch.atan(w2 / h2) - torch.atan(w1 / h1)
|
||||
v = (4 / math.pi ** 2) * torch.pow(atan, 2)
|
||||
alpha = v / (1 - iou + v)
|
||||
# ar = - (8 / (math.pi ** 2)) * atan * (w1 * h1)
|
||||
return iou - (rho2 / c2 + alpha * v) # CIoU
|
||||
|
||||
return iou
|
||||
|
||||
|
|
Loading…
Reference in New Issue