From 2df8d7e9f6fd3a3e0233029bf39d4db66807a229 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 15 Mar 2019 20:40:37 +0200 Subject: [PATCH] nms speedup --- utils/utils.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index ebbea4d3..8e54f7cb 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -203,19 +203,22 @@ def compute_ap(recall, precision): def bbox_iou(box1, box2, x1y1x2y2=True): + box1 = box1.t() + box2 = box2.t() """ Returns the IoU of two bounding boxes """ if x1y1x2y2: # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] - b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] else: + # x1, y1, w1, h1 = box1 # Transform from center and width to exact coordinates - b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 - b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 - b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 - b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 + b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 + b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 + b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 + b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 # get the coordinates of the intersection rectangle inter_rect_x1 = torch.max(b1_x1, b2_x1) @@ -353,8 +356,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): # multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2]) class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1) - - # v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here v = pred[:, 4] > conf_thres v = v.nonzero().squeeze() if len(v.shape) == 0: @@ -389,13 +390,19 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): # Non-maximum suppression det_max = [] + ind = list(range(len(dc))) if nms_style == 'OR': # default - while dc.shape[0]: - det_max.append(dc[:1]) # save highest conf detection - if len(dc) == 1: # Stop if we're at the last detection - break - iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes - dc = dc[1:][iou < nms_thres] # remove ious > threshold + while len(ind): + di = dc[ind[0]:ind[0] + 1] + det_max.append(di) # save highest conf detection + reject = bbox_iou(di, dc[ind]) > nms_thres + [ind.pop(i) for i in reversed(reject.nonzero())] + # while dc.shape[0]: # SLOWER + # det_max.append(dc[:1]) # save highest conf detection + # if len(dc) == 1: # Stop if we're at the last detection + # break + # iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes + # dc = dc[1:][iou < nms_thres] # remove ious > threshold # Image Total P R mAP # 4964 5000 0.629 0.594 0.586