diff --git a/utils/utils.py b/utils/utils.py index 57d16b36..40ab7c68 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -332,16 +332,25 @@ def build_targets(model, targets): # iou of targets-anchors t, a = targets, [] - gwh = targets[:, 4:6] * layer.ng + gwh = t[:, 4:6] * layer.ng if nt: - iou = [wh_iou(x, gwh) for x in layer.anchor_vec] - iou, a = torch.stack(iou, 0).max(0) # best iou and anchor + iou = torch.stack([wh_iou(x, gwh) for x in layer.anchor_vec], 0) + + use_best = True + if use_best: + iou, a = iou.max(0) # best iou and anchor + else: + na = len(layer.anchor_vec) # number of anchors + a = torch.arange(na).view((-1, 1)).repeat([1, nt]).view(-1) + t = targets.repeat([na, 1]) + gwh = gwh.repeat([na, 1]) + iou = iou.view(-1) # use all ious # reject below threshold ious (OPTIONAL, increases P, lowers R) reject = True if reject: j = iou > iou_thres - t, a, gwh = targets[j], a[j], gwh[j] + t, a, gwh = t[j], a[j], gwh[j] # Indices b, c = t[:, :2].long().t() # target image, class