nms speedup
This commit is contained in:
parent
c1c09eb3cc
commit
2df8d7e9f6
|
@ -203,19 +203,22 @@ def compute_ap(recall, precision):
|
|||
|
||||
|
||||
def bbox_iou(box1, box2, x1y1x2y2=True):
|
||||
box1 = box1.t()
|
||||
box2 = box2.t()
|
||||
"""
|
||||
Returns the IoU of two bounding boxes
|
||||
"""
|
||||
if x1y1x2y2:
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
||||
else:
|
||||
# x1, y1, w1, h1 = box1
|
||||
# Transform from center and width to exact coordinates
|
||||
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
|
||||
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
|
||||
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
|
||||
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
|
||||
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
|
||||
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
|
||||
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
|
||||
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
||||
|
||||
# get the coordinates of the intersection rectangle
|
||||
inter_rect_x1 = torch.max(b1_x1, b2_x1)
|
||||
|
@ -353,8 +356,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
# multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])
|
||||
|
||||
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
|
||||
|
||||
# v = ((pred[:, 4] > conf_thres) & (class_prob > .4)) # TODO examine arbitrary 0.4 thres here
|
||||
v = pred[:, 4] > conf_thres
|
||||
v = v.nonzero().squeeze()
|
||||
if len(v.shape) == 0:
|
||||
|
@ -389,13 +390,19 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
|
||||
# Non-maximum suppression
|
||||
det_max = []
|
||||
ind = list(range(len(dc)))
|
||||
if nms_style == 'OR': # default
|
||||
while dc.shape[0]:
|
||||
det_max.append(dc[:1]) # save highest conf detection
|
||||
if len(dc) == 1: # Stop if we're at the last detection
|
||||
break
|
||||
iou = bbox_iou(det_max[-1], dc[1:]) # iou with other boxes
|
||||
dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
while len(ind):
|
||||
di = dc[ind[0]:ind[0] + 1]
|
||||
det_max.append(di) # save highest conf detection
|
||||
reject = bbox_iou(di, dc[ind]) > nms_thres
|
||||
[ind.pop(i) for i in reversed(reject.nonzero())]
|
||||
# while dc.shape[0]: # SLOWER
|
||||
# det_max.append(dc[:1]) # save highest conf detection
|
||||
# if len(dc) == 1: # Stop if we're at the last detection
|
||||
# break
|
||||
# iou = bbox_iou(dc[:1], dc[1:]) # iou with other boxes
|
||||
# dc = dc[1:][iou < nms_thres] # remove ious > threshold
|
||||
|
||||
# Image Total P R mAP
|
||||
# 4964 5000 0.629 0.594 0.586
|
||||
|
|
Loading…
Reference in New Issue