updates
This commit is contained in:
parent
ae41d5855a
commit
b901441e76
|
@ -430,6 +430,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
||||||
det_max.append(dc[:1])
|
det_max.append(dc[:1])
|
||||||
dc = dc[i == 0]
|
dc = dc[i == 0]
|
||||||
|
|
||||||
|
elif nms_style == 'SOFT': # soft-NMS https://arxiv.org/abs/1704.04503
|
||||||
|
sigma = nms_thres # soft-nms sigma parameter
|
||||||
|
while len(dc):
|
||||||
|
if len(dc) == 1:
|
||||||
|
det_max.append(dc)
|
||||||
|
break
|
||||||
|
det_max.append(dc[:1])
|
||||||
|
iou = bbox_iou(dc[0], dc[1:]) # iou with other boxes
|
||||||
|
dc = dc[1:]
|
||||||
|
dc[:, 4] *= torch.exp(-iou ** 2 / sigma) # decay confidences
|
||||||
|
|
||||||
if len(det_max):
|
if len(det_max):
|
||||||
det_max = torch.cat(det_max) # concatenate
|
det_max = torch.cat(det_max) # concatenate
|
||||||
output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort
|
output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort
|
||||||
|
|
Loading…
Reference in New Issue