merge NMS speed/memory improvements

This commit is contained in:
Glenn Jocher 2020-03-31 15:37:23 -07:00
parent 992e0d7cb4
commit f4eecef700
1 changed files with 11 additions and 13 deletions

View File

@ -304,14 +304,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
return iou
def box_iou(boxes1, boxes2):
def box_iou(box1, box2):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
boxes1 (Tensor[N, 4])
boxes2 (Tensor[M, 4])
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in box1 and box2
@ -321,13 +321,11 @@ def box_iou(boxes1, boxes2):
# box = 4xn
return (box[2] - box[0]) * (box[3] - box[1])
area1 = box_area(boxes1.t())
area2 = box_area(boxes2.t())
area1 = box_area(box1.t())
area2 = box_area(box2.t())
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
inter = (rb - lt).clamp(0).prod(2) # [N,M]
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
@ -509,6 +507,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
nc = prediction[0].shape[1] - 5 # number of classes
multi_label &= nc > 1 # multiple labels per box
output = [None] * len(prediction)
for xi, x in enumerate(prediction): # image index, image inference
# Apply conf constraint
x = x[x[:, 4] > conf_thres]
@ -556,10 +555,9 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
if n < 1000: # update boxes
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
weights = (iou > iou_thres) * scores.view(-1, 1)
weights /= weights.sum(0)
if n < 5000: # update boxes
weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1) # box weights
weights /= weights.sum(0) # normalize
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
elif method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)