diff --git a/utils/utils.py b/utils/utils.py index bfdddeef..cb2ea0ff 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -288,22 +288,23 @@ def build_targets(model, targets): if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): model = model.module + nt = len(targets) txy, twh, tcls, indices = [], [], [], [] for i in model.yolo_layers: layer = model.module_list[i][0] # iou of targets-anchors + t, a = targets, [] gwh = targets[:, 4:6] * layer.nG - iou = [wh_iou(x, gwh) for x in layer.anchor_vec] - iou, a = torch.stack(iou, 0).max(0) # best iou and anchor + if nt: + iou = [wh_iou(x, gwh) for x in layer.anchor_vec] + iou, a = torch.stack(iou, 0).max(0) # best iou and anchor - # reject below threshold ious (OPTIONAL, increases P, lowers R) - reject = True - if reject: - j = iou > 0.10 - t, a, gwh = targets[j], a[j], gwh[j] - else: - t = targets + # reject below threshold ious (OPTIONAL, increases P, lowers R) + reject = True + if reject: + j = iou > 0.10 + t, a, gwh = targets[j], a[j], gwh[j] # Indices b, c = t[:, :2].long().t() # target image, class @@ -320,7 +321,7 @@ def build_targets(model, targets): # Class tcls.append(c) - if c.shape[0]: + if nt: assert c.max() <= layer.nC, 'Target classes exceed model classes' return txy, twh, tcls, indices