diff --git a/utils/utils.py b/utils/utils.py index 312b8a95..852f3d83 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -539,14 +539,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru # Apply finite constraint pred = pred[torch.isfinite(pred).all(1)] - # Get detections sorted by decreasing confidence scores - pred = pred[pred[:, 4].argsort(descending=True)] - # Batched NMS if method == 'vision_batch': output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)] continue + # Sort by confidence + if not method.startswith('vision'): + pred = pred[pred[:, 4].argsort(descending=True)] + # All other NMS methods det_max = [] cls = pred[:, -1]