diff --git a/utils/utils.py b/utils/utils.py index 6abd2b30..fc9abc67 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -501,10 +501,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru # Batched NMS if method == 'vision_batch': - i = torchvision.ops.boxes.batched_nms(boxes=pred[:, :4], - scores=pred[:, 4], - idxs=pred[:, 6], - iou_threshold=nms_thres) + i = torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres) output[image_i] = pred[i] continue