diff --git a/train.py b/train.py index 720e0b1e..06878e6c 100644 --- a/train.py +++ b/train.py @@ -121,10 +121,9 @@ def train( for i, (imgs, targets, _, _) in enumerate(dataloader): imgs = imgs.to(device) targets = targets.to(device) - nt = len(targets) - if nt == 0: # if no targets continue - continue + # if nt == 0: # if no targets continue + # continue # Plot images with bounding boxes if epoch == 0 and i == 0: diff --git a/utils/utils.py b/utils/utils.py index cb2ea0ff..d060b9c1 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -321,7 +321,7 @@ def build_targets(model, targets): # Class tcls.append(c) - if nt: + if c.shape[0]: assert c.max() <= layer.nC, 'Target classes exceed model classes' return txy, twh, tcls, indices