This commit is contained in:
glenn-jocher 2019-07-04 22:50:03 +02:00
parent abf59f1565
commit 7246dd855c
1 changed files with 13 additions and 4 deletions

View File

@ -332,16 +332,25 @@ def build_targets(model, targets):
# iou of targets-anchors
t, a = targets, []
gwh = targets[:, 4:6] * layer.ng
gwh = t[:, 4:6] * layer.ng
if nt:
iou = [wh_iou(x, gwh) for x in layer.anchor_vec]
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
use_best = True
if use_best:
iou, a = iou.max(0) # best iou and anchor
else:
na = len(layer.anchor_vec) # number of anchors
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
t = targets.repeat([na, 1])
gwh = gwh.repeat([na, 1])
iou = iou.view(-1) # use all ious
# reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True
if reject:
j = iou > iou_thres
t, a, gwh = targets[j], a[j], gwh[j]
t, a, gwh = t[j], a[j], gwh[j]
# Indices
b, c = t[:, :2].long().t() # target image, class