diff --git a/utils/utils.py b/utils/utils.py index 6ffdfee8..6abd2b30 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -459,7 +459,7 @@ def build_targets(model, targets): return tcls, tbox, indices, av -def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=False, method='vision'): +def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision'): """ Removes detections with lower object confidence score than 'conf_thres' Non-Maximum Suppression to further filter detections. @@ -489,7 +489,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Fal pred[:, :4] = xywh2xyxy(pred[:, :4]) # Multi-class - if multi_cls: + if multi_cls or conf_thres < 0.01: i, j = (pred[:, 4:] > conf_thres).nonzero().t() pred = torch.cat((pred[i, :4], pred[i, j + 4].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only