merge_batch NMS method
This commit is contained in:
parent
3265d50f69
commit
aa0c64b5ac
|
@ -558,9 +558,15 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
boxes += c.view(-1, 1) * max_wh
|
boxes += c.view(-1, 1) * max_wh
|
||||||
if method == 'vision_batch':
|
if method == 'vision_batch':
|
||||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
|
elif method == 'merge_batch': # Merge NMS
|
||||||
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
|
iou = box_iou(boxes, boxes[i]).tril_() # upper triangular iou matrix
|
||||||
|
weights = (iou > conf_thres) * scores.view(-1, 1)
|
||||||
|
weights /= weights.sum(0)
|
||||||
|
pred[i, :4] = torch.matmul(weights.T, pred[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||||
elif method == 'fast_batch': # FastNMS from https://github.com/dbolya/yolact
|
elif method == 'fast_batch': # FastNMS from https://github.com/dbolya/yolact
|
||||||
iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix
|
iou = box_iou(boxes, boxes).triu_(diagonal=1) # upper triangular iou matrix
|
||||||
i = iou.max(dim=0)[0] < iou_thres
|
i = iou.max(0)[0] < iou_thres
|
||||||
|
|
||||||
output[image_i] = pred[i]
|
output[image_i] = pred[i]
|
||||||
continue
|
continue
|
||||||
|
@ -577,10 +583,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
elif n > 500:
|
elif n > 500:
|
||||||
dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
|
dc = dc[:500] # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117
|
||||||
|
|
||||||
if method == 'vision':
|
if method == 'or': # default
|
||||||
det_max.append(dc[torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], iou_thres)])
|
|
||||||
|
|
||||||
elif method == 'or': # default
|
|
||||||
# METHOD1
|
# METHOD1
|
||||||
# ind = list(range(len(dc)))
|
# ind = list(range(len(dc)))
|
||||||
# while len(ind):
|
# while len(ind):
|
||||||
|
@ -629,7 +632,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
|
|
||||||
if len(det_max):
|
if len(det_max):
|
||||||
det_max = torch.cat(det_max) # concatenate
|
det_max = torch.cat(det_max) # concatenate
|
||||||
output[image_i] = det_max[(-det_max[:, 4]).argsort()] # sort
|
output[image_i] = det_max[det_max[:, 4].argsort(descending=True)] # sort
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue