merge NMS speed/memory improvements
This commit is contained in:
parent
992e0d7cb4
commit
f4eecef700
|
@ -304,14 +304,14 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
|
|||
return iou
|
||||
|
||||
|
||||
def box_iou(boxes1, boxes2):
|
||||
def box_iou(box1, box2):
|
||||
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
|
||||
"""
|
||||
Return intersection-over-union (Jaccard index) of boxes.
|
||||
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
||||
Arguments:
|
||||
boxes1 (Tensor[N, 4])
|
||||
boxes2 (Tensor[M, 4])
|
||||
box1 (Tensor[N, 4])
|
||||
box2 (Tensor[M, 4])
|
||||
Returns:
|
||||
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
||||
IoU values for every element in boxes1 and boxes2
|
||||
|
@ -321,13 +321,11 @@ def box_iou(boxes1, boxes2):
|
|||
# box = 4xn
|
||||
return (box[2] - box[0]) * (box[3] - box[1])
|
||||
|
||||
area1 = box_area(boxes1.t())
|
||||
area2 = box_area(boxes2.t())
|
||||
area1 = box_area(box1.t())
|
||||
area2 = box_area(box2.t())
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
inter = (rb - lt).clamp(0).prod(2) # [N,M]
|
||||
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
||||
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
|
||||
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
||||
|
||||
|
||||
|
@ -509,6 +507,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
|||
nc = prediction[0].shape[1] - 5 # number of classes
|
||||
multi_label &= nc > 1 # multiple labels per box
|
||||
output = [None] * len(prediction)
|
||||
|
||||
for xi, x in enumerate(prediction): # image index, image inference
|
||||
# Apply conf constraint
|
||||
x = x[x[:, 4] > conf_thres]
|
||||
|
@ -556,10 +555,9 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
|
|||
boxes, scores = x[:, :4].clone() + c.view(-1, 1) * max_wh, x[:, 4] # boxes (offset by class), scores
|
||||
if method == 'merge': # Merge NMS (boxes merged using weighted mean)
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
if n < 1000: # update boxes
|
||||
iou = box_iou(boxes, boxes).tril_() # lower triangular iou matrix
|
||||
weights = (iou > iou_thres) * scores.view(-1, 1)
|
||||
weights /= weights.sum(0)
|
||||
if n < 5000: # update boxes
|
||||
weights = (box_iou(boxes, boxes).tril_() > iou_thres) * scores.view(-1, 1) # box weights
|
||||
weights /= weights.sum(0) # normalize
|
||||
x[:, :4] = torch.mm(weights.T, x[:, :4]) # merged_boxes(n,4) = weights(n,n) * boxes(n,4)
|
||||
elif method == 'vision':
|
||||
i = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
|
||||
|
|
Loading…
Reference in New Issue