This commit is contained in:
Glenn Jocher 2019-01-01 17:52:45 +01:00
parent 0bb3fcb049
commit 7283f26f6f
1 changed files with 9 additions and 11 deletions

View File

@ -222,8 +222,6 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
if nTb == 0: if nTb == 0:
continue continue
t = target[b] t = target[b]
if batch_report:
FN[b, :nTb] = 1
# Convert to position relative to box # Convert to position relative to box
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
@ -233,25 +231,25 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
# iou of targets-anchors (using wh only) # iou of targets-anchors (using wh only)
box1 = t[:, 3:5] * nG box1 = t[:, 3:5] * nG
# box2 = anchor_grid_wh[:, gj, gi] box2 = anchor_wh.unsqueeze(1)
box2 = anchor_wh.unsqueeze(1).repeat(1, nTb, 1)
inter_area = torch.min(box1, box2).prod(2) inter_area = torch.min(box1, box2).prod(2)
iou_anch = inter_area / (gw * gh + box2.prod(2) - inter_area + 1e-16) iou = inter_area / (gw * gh + box2.prod(2) - inter_area + 1e-16)
# Select best iou_pred and anchor # Select best iou_pred and anchor
iou_anch_best, a = iou_anch.max(0) # best anchor [0-2] for each target iou_best, a = iou.max(0) # best anchor [0-2] for each target
# Select best unique target-anchor combinations # Select best unique target-anchor combinations
if nTb > 1: if nTb > 1:
iou_order = np.argsort(-iou_anch_best) # best to worst iou_order = torch.argsort(-iou_best) # best to worst
# Unique anchor selection (slower but retains original order) # Unique anchor selection
u = torch.cat((gi, gj, a), 0).view(3, -1).numpy() u = torch.cat((gi, gj, a), 0).view(3, -1)
_, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices _, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices
# _, first_unique = torch.unique(u[:, iou_order], dim=1, return_inverse=True) # different than numpy?
i = iou_order[first_unique] i = iou_order[first_unique]
# best anchor must share significant commonality (iou) with target # best anchor must share significant commonality (iou) with target
i = i[iou_anch_best[i] > 0.10] i = i[iou_best[i] > 0.10]
if len(i) == 0: if len(i) == 0:
continue continue
@ -259,7 +257,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
if len(t.shape) == 1: if len(t.shape) == 1:
t = t.view(1, 5) t = t.view(1, 5)
else: else:
if iou_anch_best < 0.10: if iou_best < 0.10:
continue continue
i = 0 i = 0