updates
This commit is contained in:
parent
0bb3fcb049
commit
7283f26f6f
|
@ -222,8 +222,6 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
|
|||
if nTb == 0:
|
||||
continue
|
||||
t = target[b]
|
||||
if batch_report:
|
||||
FN[b, :nTb] = 1
|
||||
|
||||
# Convert to position relative to box
|
||||
TC[b, :nTb], gx, gy, gw, gh = t[:, 0].long(), t[:, 1] * nG, t[:, 2] * nG, t[:, 3] * nG, t[:, 4] * nG
|
||||
|
@ -233,25 +231,25 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
|
|||
|
||||
# iou of targets-anchors (using wh only)
|
||||
box1 = t[:, 3:5] * nG
|
||||
# box2 = anchor_grid_wh[:, gj, gi]
|
||||
box2 = anchor_wh.unsqueeze(1).repeat(1, nTb, 1)
|
||||
box2 = anchor_wh.unsqueeze(1)
|
||||
inter_area = torch.min(box1, box2).prod(2)
|
||||
iou_anch = inter_area / (gw * gh + box2.prod(2) - inter_area + 1e-16)
|
||||
iou = inter_area / (gw * gh + box2.prod(2) - inter_area + 1e-16)
|
||||
|
||||
# Select best iou_pred and anchor
|
||||
iou_anch_best, a = iou_anch.max(0) # best anchor [0-2] for each target
|
||||
iou_best, a = iou.max(0) # best anchor [0-2] for each target
|
||||
|
||||
# Select best unique target-anchor combinations
|
||||
if nTb > 1:
|
||||
iou_order = np.argsort(-iou_anch_best) # best to worst
|
||||
iou_order = torch.argsort(-iou_best) # best to worst
|
||||
|
||||
# Unique anchor selection (slower but retains original order)
|
||||
u = torch.cat((gi, gj, a), 0).view(3, -1).numpy()
|
||||
# Unique anchor selection
|
||||
u = torch.cat((gi, gj, a), 0).view(3, -1)
|
||||
_, first_unique = np.unique(u[:, iou_order], axis=1, return_index=True) # first unique indices
|
||||
# _, first_unique = torch.unique(u[:, iou_order], dim=1, return_inverse=True) # different than numpy?
|
||||
|
||||
i = iou_order[first_unique]
|
||||
# best anchor must share significant commonality (iou) with target
|
||||
i = i[iou_anch_best[i] > 0.10]
|
||||
i = i[iou_best[i] > 0.10]
|
||||
if len(i) == 0:
|
||||
continue
|
||||
|
||||
|
@ -259,7 +257,7 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
|
|||
if len(t.shape) == 1:
|
||||
t = t.view(1, 5)
|
||||
else:
|
||||
if iou_anch_best < 0.10:
|
||||
if iou_best < 0.10:
|
||||
continue
|
||||
i = 0
|
||||
|
||||
|
|
Loading…
Reference in New Issue