diff --git a/utils/utils.py b/utils/utils.py index ad0ca47d..b0047217 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -430,6 +430,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): det_max.append(dc[:1]) dc = dc[i == 0] + elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503 + sigma = nms_thres # soft-nms sigma parameter + while len(dc): + if len(dc) == 1: + det_max.append(dc) + break + det_max.append(dc[:1]) + iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes + dc = dc[1:] + dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences + if len(det_max): det_max = torch.cat(det_max) # concatenate output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort