From 4942aacef963515797bee84df89ae4b54f75e903 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 8 Dec 2019 17:19:42 -0800 Subject: [PATCH] updates --- utils/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index c33ccc70..bfba712c 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -406,6 +406,7 @@ def build_targets(model, targets): nt = len(targets) tcls, tbox, indices, av = [], [], [], [] multi_gpu = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) + reject, use_all_anchors = True, True for i in model.yolo_layers: # get number of grid points and anchor vec for this yolo layer if multi_gpu: @@ -419,17 +420,15 @@ def build_targets(model, targets): if nt: iou = torch.stack([wh_iou(x, gwh) for x in anchor_vec], 0) - use_best_anchor = False - if use_best_anchor: - iou, a = iou.max(0) # best iou and anchor - else: # use all anchors + if use_all_anchors: na = len(anchor_vec) # number of anchors a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1) t = targets.repeat([na, 1]) gwh = gwh.repeat([na, 1]) + else: # use best anchor only + iou, a = iou.max(0) # best iou and anchor # reject anchors below iou_thres (OPTIONAL, increases P, lowers R) - reject = True if reject: j = iou.view(-1) > model.hyp['iou_t'] # iou threshold hyperparameter t, a, gwh = t[j], a[j], gwh[j]