nms speedup

This commit is contained in:
Glenn Jocher 2019-03-15 20:40:37 +02:00 committed by GitHub
parent c1c09eb3cc
commit 2df8d7e9f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 21 additions and 14 deletions

View File

@ -203,19 +203,22 @@ def compute_ap(recall, precision):
def bbox_iou(box1, box2, x1y1x2y2=True):
box1 = box1.t()
box2 = box2.t()
"""
Returns the IoU of two bounding boxes
"""
if x1y1x2y2:
# Get the coordinates of bounding boxes
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
else:
# x1, y1, w1, h1 = box1
# Transform from center and width to exact coordinates
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
# get the coordinates of the intersection rectangle
inter_rect_x1 = torch.max(b1_x1, b2_x1)
@ -353,8 +356,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
# v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here
v = pred[:, 4] > conf_thres
v = v.nonzero().squeeze()
if len(v.shape) == 0:
@ -389,13 +390,19 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# Non-maximum suppression
det_max = []
ind = list(range(len(dc)))
if nms_style == 'OR': # default
while dc.shape[0]:
det_max.append(dc[:1]) # save highest conf detection
if len(dc) == 1: # Stop if we're at the last detection
break
iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes
dc = dc[1:][iou < nms_thres] # remove ious > threshold
while len(ind):
di = dc[ind[0]:ind[0] + 1]
det_max.append(di) # save highest conf detection
reject = bbox_iou(di, dc[ind]) > nms_thres
[ind.pop(i) for i in reversed(reject.nonzero())]
# while dc.shape[0]: # SLOWER
# det_max.append(dc[:1]) # save highest conf detection
# if len(dc) == 1: # Stop if we're at the last detection
# break
# iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
# dc = dc[1:][iou < nms_thres] # remove ious > threshold
# Image Total P R mAP
# 4964 5000 0.629 0.594 0.586