updates
This commit is contained in:
parent
654b9834c2
commit
a0b4d17f7e
|
@ -313,15 +313,12 @@ def box_iou(boxes1, boxes2):
|
||||||
return iou
|
return iou
|
||||||
|
|
||||||
|
|
||||||
def wh_iou(box1, box2):
|
def wh_iou(wh1, wh2):
|
||||||
# Returns the IoU of wh1 to wh2. wh1 is 2, wh2 is 2xn
|
# Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
|
||||||
w1, h1 = box1[0], box1[1]
|
wh1 = wh1[:, None] # [N,1,2]
|
||||||
w2, h2 = box2[0], box2[1]
|
wh2 = wh2[None] # [1,M,2]
|
||||||
|
inter = torch.min(wh1, wh2).prod(2) # [N,M]
|
||||||
# Intersection area
|
return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter)
|
||||||
inter = torch.min(w1, w2) * torch.min(h1, h2)
|
|
||||||
|
|
||||||
return inter / (w1 * h1 + w2 * h2 - inter) # iou = inter / (area1 + area2 - inter)
|
|
||||||
|
|
||||||
|
|
||||||
class FocalLoss(nn.Module):
|
class FocalLoss(nn.Module):
|
||||||
|
@ -445,9 +442,8 @@ def build_targets(model, targets):
|
||||||
# iou of targets-anchors
|
# iou of targets-anchors
|
||||||
t, a = targets, []
|
t, a = targets, []
|
||||||
gwh = t[:, 4:6] * ng
|
gwh = t[:, 4:6] * ng
|
||||||
gwht = gwh.t()
|
|
||||||
if nt:
|
if nt:
|
||||||
iou = torch.stack([wh_iou(x, gwht) for x in anchor_vec], 0)
|
iou = wh_iou(anchor_vec, gwh)
|
||||||
|
|
||||||
if use_all_anchors:
|
if use_all_anchors:
|
||||||
na = len(anchor_vec) # number of anchors
|
na = len(anchor_vec) # number of anchors
|
||||||
|
|
Loading…
Reference in New Issue