diff --git a/utils/utils.py b/utils/utils.py index 911d3f1a..62e2cf49 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -524,10 +524,6 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru # Apply width-height constraint pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)] - # If none remain process next image - if len(pred) == 0: - continue - # Compute conf pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf @@ -550,6 +546,10 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru if not torch.isfinite(pred).all(): pred = pred[torch.isfinite(pred).all(1)] + # If none remain process next image + if not pred.shape[0]: + continue + # Batched NMS if method == 'vision_batch': c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS