FIX: trainig fails if targets list is empty (#198)
* FIX: trainig fails if targets list is empty * Update utils.py
This commit is contained in:
parent
24f86b008a
commit
5ea92e7ee2
|
@ -288,12 +288,15 @@ def build_targets(model, targets):
|
||||||
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
if type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel):
|
||||||
model = model.module
|
model = model.module
|
||||||
|
|
||||||
|
nt = len(targets)
|
||||||
txy, twh, tcls, indices = [], [], [], []
|
txy, twh, tcls, indices = [], [], [], []
|
||||||
for i in model.yolo_layers:
|
for i in model.yolo_layers:
|
||||||
layer = model.module_list[i][0]
|
layer = model.module_list[i][0]
|
||||||
|
|
||||||
# iou of targets-anchors
|
# iou of targets-anchors
|
||||||
|
t, a = targets, []
|
||||||
gwh = targets[:, 4:6] * layer.nG
|
gwh = targets[:, 4:6] * layer.nG
|
||||||
|
if nt:
|
||||||
iou = [wh_iou(x, gwh) for x in layer.anchor_vec]
|
iou = [wh_iou(x, gwh) for x in layer.anchor_vec]
|
||||||
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
|
iou, a = torch.stack(iou, 0).max(0) # best iou and anchor
|
||||||
|
|
||||||
|
@ -302,8 +305,6 @@ def build_targets(model, targets):
|
||||||
if reject:
|
if reject:
|
||||||
j = iou > 0.10
|
j = iou > 0.10
|
||||||
t, a, gwh = targets[j], a[j], gwh[j]
|
t, a, gwh = targets[j], a[j], gwh[j]
|
||||||
else:
|
|
||||||
t = targets
|
|
||||||
|
|
||||||
# Indices
|
# Indices
|
||||||
b, c = t[:, :2].long().t() # target image, class
|
b, c = t[:, :2].long().t() # target image, class
|
||||||
|
@ -320,7 +321,7 @@ def build_targets(model, targets):
|
||||||
|
|
||||||
# Class
|
# Class
|
||||||
tcls.append(c)
|
tcls.append(c)
|
||||||
if c.shape[0]:
|
if nt:
|
||||||
assert c.max() <= layer.nC, 'Target classes exceed model classes'
|
assert c.max() <= layer.nC, 'Target classes exceed model classes'
|
||||||
|
|
||||||
return txy, twh, tcls, indices
|
return txy, twh, tcls, indices
|
||||||
|
|
Loading…
Reference in New Issue