updates
This commit is contained in:
parent
ae41d5855a
commit
b901441e76
|
@ -430,6 +430,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
det_max.append(dc[:1])
|
||||
dc = dc[i == 0]
|
||||
|
||||
elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
|
||||
sigma = nms_thres # soft-nms sigma parameter
|
||||
while len(dc):
|
||||
if len(dc) == 1:
|
||||
det_max.append(dc)
|
||||
break
|
||||
det_max.append(dc[:1])
|
||||
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||
dc = dc[1:]
|
||||
dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences
|
||||
|
||||
if len(det_max):
|
||||
det_max = torch.cat(det_max) # concatenate
|
||||
output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort
|
||||
|
|
Loading…
Reference in New Issue