From 0772ebf7c9f2cf253b51299215db68a44c1c3671 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 19 Feb 2019 22:19:59 +0100 Subject: [PATCH] xy and wh losses respectively merged --- utils/utils.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index e6c1a33d..3a0074b0 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -231,18 +231,16 @@ def build_targets(target, anchor_wh, nA, nC, nG): continue t = target[b] - # Convert to position relative to box - gx, gy, gw, gh = t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG + gxy, gwh = t[:, 1:3] * nG, t[:, 3:5] * nG # Get grid box indices and prevent overflows (i.e. 13.01 on 13 anchors) - gi = torch.clamp(gx.long(), min=0, max=nG - 1) - gj = torch.clamp(gy.long(), min=0, max=nG - 1) + gi, gj = torch.clamp(gxy.long(), min=0, max=nG - 1).t() # iou of targets-anchors (using wh only) - box1 = t[:, 3:5] * nG + box1 = gwh box2 = anchor_wh.unsqueeze(1) inter_area = torch.min(box1, box2).prod(2) - iou = inter_area / (gw * gh + box2.prod(2) - inter_area + 1e-16) + iou = inter_area / (box1.prod(1) + box2.prod(2) - inter_area + 1e-16) # Select best iou_pred and anchor iou_best, a = iou.max(0) # best anchor [0-2] for each target @@ -269,17 +267,14 @@ def build_targets(target, anchor_wh, nA, nC, nG): if iou_best < 0.10: continue - tc, gx, gy, gwh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3:5] * nG + tc, gxy, gwh = t[:, 0].long(), t[:, 1:3] * nG, t[:, 3:5] * nG - # Coordinates - txy[b, a, gj, gi, 0] = gx - gi.float() - txy[b, a, gj, gi, 1] = gy - gj.float() + # XY coordinates + txy[b, a, gj, gi] = gxy - gxy.floor() - # Width and height (yolo method) - twh[b, a, gj, gi] = torch.log(gwh / anchor_wh[a]) - - # Width and height (power method) - # twh[b, a, gj, gi] = torch.sqrt(gwh / anchor_wh[a]) / 2 + # Width and height + twh[b, a, gj, gi] = torch.log(gwh / anchor_wh[a]) # yolo method + # twh[b, a, gj, gi] = torch.sqrt(gwh / anchor_wh[a]) / 2 # power method # One-hot encoding of label tcls[b, a, gj, gi, tc] = 1