merge NMS speed/memory improvements
This commit is contained in:
parent
992e0d7cb4
commit
f4eecef700
|
@ -304,14 +304,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
|
||||||
return iou
|
return iou
|
||||||
|
|
||||||
|
|
||||||
def box_iou(boxes1, boxes2):
|
def box_iou(box1, box2):
|
||||||
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
||||||
"""
|
"""
|
||||||
Return intersection-over-union (Jaccard index) of boxes.
|
Return intersection-over-union (Jaccard index) of boxes.
|
||||||
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
||||||
Arguments:
|
Arguments:
|
||||||
boxes1 (Tensor[N, 4])
|
box1 (Tensor[N, 4])
|
||||||
boxes2 (Tensor[M, 4])
|
box2 (Tensor[M, 4])
|
||||||
Returns:
|
Returns:
|
||||||
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
||||||
IoU values for every element in boxes1 and boxes2
|
IoU values for every element in box1 and box2
|
||||||
|
@ -321,13 +321,11 @@ def box_iou(boxes1, boxes2):
|
||||||
# box = 4xn
|
# box = 4xn
|
||||||
return (box[2] - box[0]) * (box[3] - box[1])
|
return (box[2] - box[0]) * (box[3] - box[1])
|
||||||
|
|
||||||
area1 = box_area(boxes1.t())
|
area1 = box_area(box1.t())
|
||||||
area2 = box_area(boxes2.t())
|
area2 = box_area(box2.t())
|
||||||
|
|
||||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
|
||||||
|
|
||||||
inter = (rb - lt).clamp(0).prod(2) # [N,M]
|
|
||||||
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
||||||
|
|
||||||
|
|
||||||
|
@ -509,6 +507,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
nc = prediction[0].shape[1] - 5 # number of classes
|
nc = prediction[0].shape[1] - 5 # number of classes
|
||||||
multi_label &= nc > 1 # multiple labels per box
|
multi_label &= nc > 1 # multiple labels per box
|
||||||
output = [None] * len(prediction)
|
output = [None] * len(prediction)
|
||||||
|
|
||||||
for xi, x in enumerate(prediction): # image index, image inference
|
for xi, x in enumerate(prediction): # image index, image inference
|
||||||
# Apply conf constraint
|
# Apply conf constraint
|
||||||
x = x[x[:, 4] > conf_thres]
|
x = x[x[:, 4] > conf_thres]
|
||||||
|
@ -556,10 +555,9 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
||||||
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
||||||
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
if n < 1000: # update boxes
|
if n < 5000: # update boxes
|
||||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1) # box weights
|
||||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
weights /= weights.sum(0) # normalize
|
||||||
weights /= weights.sum(0)
|
|
||||||
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||||
elif method == 'vision':
|
elif method == 'vision':
|
||||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||||
|
|
Loading…
Reference in New Issue