diff --git a/utils/utils.py b/utils/utils.py index 91dc62a9..4a7e969a 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -549,7 +549,7 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru # Batched NMS if method == 'vision_batch': - c = j * 0 if agnostic else j # class-agnostic NMS + c = pred[:, 5] * 0 if agnostic else pred[:, 5] # class-agnostic NMS output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], c, iou_thres)] continue