This commit is contained in:
Glenn Jocher 2019-12-08 17:19:42 -08:00
parent b913d1ab55
commit 4942aacef9
1 changed files with 4 additions and 5 deletions

View File

@ -406,6 +406,7 @@ def build_targets(model, targets):
nt = len(targets) nt = len(targets)
tcls, tbox, indices, av = [], [], [], [] tcls, tbox, indices, av = [], [], [], []
multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
reject, use_all_anchors = True, True
for i in model.yolo_layers: for i in model.yolo_layers:
# get number of grid points and anchor vec for this yolo layer # get number of grid points and anchor vec for this yolo layer
if multi_gpu: if multi_gpu:
@ -419,17 +420,15 @@ def build_targets(model, targets):
if nt: if nt:
iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0) iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0)
use_best_anchor = False if use_all_anchors:
if use_best_anchor:
iou, a = iou.max(0) # best iou and anchor
else: # use all anchors
na = len(anchor_vec) # number of anchors na = len(anchor_vec) # number of anchors
a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1) a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1)
t = targets.repeat([na, 1]) t = targets.repeat([na, 1])
gwh = gwh.repeat([na, 1]) gwh = gwh.repeat([na, 1])
else: # use best anchor only
iou, a = iou.max(0) # best iou and anchor
# reject anchors below iou_thres (OPTIONAL, increases P, lowers R) # reject anchors below iou_thres (OPTIONAL, increases P, lowers R)
reject = True
if reject: if reject:
j = iou.view(-1) > model.hyp['iou_t'] # iou threshold hyperparameter j = iou.view(-1) > model.hyp['iou_t'] # iou threshold hyperparameter
t, a, gwh = t[j], a[j], gwh[j] t, a, gwh = t[j], a[j], gwh[j]