This commit is contained in:
Glenn Jocher 2019-05-02 23:56:58 +02:00
parent ae41d5855a
commit b901441e76
1 changed files with 11 additions and 0 deletions

View File

@ -430,6 +430,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
det_max.append(dc[:1]) det_max.append(dc[:1])
dc = dc[i == 0] dc = dc[i == 0]
elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
sigma = nms_thres # soft-nms sigma parameter
while len(dc):
if len(dc) == 1:
det_max.append(dc)
break
det_max.append(dc[:1])
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
dc = dc[1:]
dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences
if len(det_max): if len(det_max):
det_max = torch.cat(det_max) # concatenate det_max = torch.cat(det_max) # concatenate
output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort