This commit is contained in:
glenn-jocher 2019-07-04 22:50:03 +02:00
parent abf59f1565
commit 7246dd855c
1 changed files with 13 additions and 4 deletions

View File

@ -332,16 +332,25 @@ def build_targets(model, targets):
# iou of targets-anchors # iou of targets-anchors
t, a = targets, [] t, a = targets, []
gwh = targets[:, 4:6] * layer.ng gwh = t[:, 4:6] * layer.ng
if nt: if nt:
iou = [wh_iou(x, gwh) for x in layer.anchor_vec] iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0)
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
use_best = True
if use_best:
iou, a = iou.max(0) # best iou and anchor
else:
na = len(layer.anchor_vec) # number of anchors
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
t = targets.repeat([na, 1])
gwh = gwh.repeat([na, 1])
iou = iou.view(-1) # use all ious
# reject below threshold ious (OPTIONAL, increases P, lowers R) # reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True reject = True
if reject: if reject:
j = iou > iou_thres j = iou > iou_thres
t, a, gwh = targets[j], a[j], gwh[j] t, a, gwh = t[j], a[j], gwh[j]
# Indices # Indices
b, c = t[:, :2].long().t() # target image, class b, c = t[:, :2].long().t() # target image, class