diff --git a/utils/utils.py b/utils/utils.py index f3403120..b2dad445 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -501,7 +501,7 @@ def build_targets(model, targets): return tcls, tbox, indices, av -def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=True, classes=None, agnostic=False): +def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=True, classes=None, agnostic=False): """ Removes detections with lower object confidence score than 'conf_thres' Non-Maximum Suppression to further filter detections. @@ -516,7 +516,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=Tru method = 'vision_batch' batched = 'batch' in method # run once per image, all classes simultaneously nc = prediction[0].shape[1] - 5 # number of classes - multi_cls = multi_cls and (nc > 1) # allow multiple classes per anchor + multi_label &= nc > 1 # multiple labels per box output = [None] * len(prediction) for image_i, pred in enumerate(prediction): # Apply conf constraint @@ -536,7 +536,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_cls=Tru box = xywh2xyxy(pred[:, :4]) # Detections matrix nx6 (xyxy, conf, cls) - if multi_cls: + if multi_label: i, j = (pred[:, 5:] > conf_thres).nonzero().t() pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only