FIX: trainig fails if targets list is empty (#198)

* FIX: trainig fails if targets list is empty

* Update utils.py
This commit is contained in:
IlyaOvodov 2019-04-12 15:55:26 +03:00 committed by Glenn Jocher
parent 24f86b008a
commit 5ea92e7ee2
1 changed files with 11 additions and 10 deletions

View File

@ -288,22 +288,23 @@ def build_targets(model, targets):
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel): if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
model = model.module model = model.module
nt = len(targets)
txy, twh, tcls, indices = [], [], [], [] txy, twh, tcls, indices = [], [], [], []
for i in model.yolo_layers: for i in model.yolo_layers:
layer = model.module_list[i][0] layer = model.module_list[i][0]
# iou of targets-anchors # iou of targets-anchors
t, a = targets, []
gwh = targets[:, 4:6] * layer.nG gwh = targets[:, 4:6] * layer.nG
iou = [wh_iou(x, gwh) for x in layer.anchor_vec] if nt:
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor iou = [wh_iou(x, gwh) for x in layer.anchor_vec]
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
# reject below threshold ious (OPTIONAL, increases P, lowers R) # reject below threshold ious (OPTIONAL, increases P, lowers R)
reject = True reject = True
if reject: if reject:
j = iou > 0.10 j = iou > 0.10
t, a, gwh = targets[j], a[j], gwh[j] t, a, gwh = targets[j], a[j], gwh[j]
else:
t = targets
# Indices # Indices
b, c = t[:, :2].long().t() # target image, class b, c = t[:, :2].long().t() # target image, class
@ -320,7 +321,7 @@ def build_targets(model, targets):
# Class # Class
tcls.append(c) tcls.append(c)
if c.shape[0]: if nt:
assert c.max() <= layer.nC, 'Target classes exceed model classes' assert c.max() <= layer.nC, 'Target classes exceed model classes'
return txy, twh, tcls, indices return txy, twh, tcls, indices