merge NMS speed/memory improvements

This commit is contained in:
Glenn Jocher 2020-03-31 15:37:23 -07:00
parent 992e0d7cb4
commit f4eecef700
1 changed files with 11 additions and 13 deletions

View File

@ -304,14 +304,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
return iou return iou
def box_iou(boxes1, boxes2): def box_iou(box1, box2):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
""" """
Return intersection-over-union (Jaccard index) of boxes. Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments: Arguments:
boxes1 (Tensor[N, 4]) box1 (Tensor[N, 4])
boxes2 (Tensor[M, 4]) box2 (Tensor[M, 4])
Returns: Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2 IoU values for every element in boxes1 and boxes2
@ -321,13 +321,11 @@ def box_iou(boxes1, boxes2):
# box = 4xn # box = 4xn
return (box[2] - box[0]) * (box[3] - box[1]) return (box[2] - box[0]) * (box[3] - box[1])
area1 = box_area(boxes1.t()) area1 = box_area(box1.t())
area2 = box_area(boxes2.t()) area2 = box_area(box2.t())
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
inter = (rb - lt).clamp(0).prod(2) # [N,M]
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter) return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
@ -509,6 +507,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
nc = prediction[0].shape[1] - 5 # number of classes nc = prediction[0].shape[1] - 5 # number of classes
multi_label &= nc > 1 # multiple labels per box multi_label &= nc > 1 # multiple labels per box
output = [None] * len(prediction) output = [None] * len(prediction)
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction): # image index, image inference
# Apply conf constraint # Apply conf constraint
x = x[x[:, 4] > conf_thres] x = x[x[:, 4] > conf_thres]
@ -556,10 +555,9 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
if method == 'merge': # Merge NMS (boxes merged using weighted mean) if method == 'merge': # Merge NMS (boxes merged using weighted mean)
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
if n < 1000: # update boxes if n < 5000: # update boxes
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1) # box weights
weights = (iou > iou_thres) * scores.view(-1, 1) weights /= weights.sum(0) # normalize
weights /= weights.sum(0)
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4) x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
elif method == 'vision': elif method == 'vision':
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres) i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)