diff --git a/utils/utils.py b/utils/utils.py index 68c9a82d..a7e67b0c 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -529,7 +529,7 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru box = xywh2xyxy(pred[:, :4]) # Detections matrix nx6 (xyxy, conf, cls) - if multi_cls or conf_thres < 0.01: + if multi_cls: i, j = (pred[:, 5:] > conf_thres).nonzero().t() pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only