From 5ea92e7ee2ad73fe891e0586ac9bf5b0c35f84e1 Mon Sep 17 00:00:00 2001 From: IlyaOvodov <34230114+IlyaOvodov@users.noreply.github.com> Date: Fri, 12 Apr 2019 15:55:26 +0300 Subject: [PATCH] FIX: trainig fails if targets list is empty (#198) * FIX: trainig fails if targets list is empty * Update utils.py --- utils/utils.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index bfdddeef..cb2ea0ff 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -288,22 +288,23 @@ def build_targets(model, targets): if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): model = model.module + nt = len(targets) txy, twh, tcls, indices = [], [], [], [] for i in model.yolo_layers: layer = model.module_list[i][0] # iou of targets-anchors + t, a = targets, [] gwh = targets[:, 4:6] * layer.nG - iou = [wh_iou(x, gwh) for x in layer.anchor_vec] - iou, a = torch.stack(iou, 0).max(0) # best iou and anchor + if nt: + iou = [wh_iou(x, gwh) for x in layer.anchor_vec] + iou, a = torch.stack(iou, 0).max(0) # best iou and anchor - # reject below threshold ious (OPTIONAL, increases P, lowers R) - reject = True - if reject: - j = iou > 0.10 - t, a, gwh = targets[j], a[j], gwh[j] - else: - t = targets + # reject below threshold ious (OPTIONAL, increases P, lowers R) + reject = True + if reject: + j = iou > 0.10 + t, a, gwh = targets[j], a[j], gwh[j] # Indices b, c = t[:, :2].long().t() # target image, class @@ -320,7 +321,7 @@ def build_targets(model, targets): # Class tcls.append(c) - if c.shape[0]: + if nt: assert c.max() <= layer.nC, 'Target classes exceed model classes' return txy, twh, tcls, indices