Fixed NMS bug causing big CPU usage. Note that using 'cross_class_nms' still takes a huge amount of time and should be fixed somehow.

This commit is contained in:
Nir Ben-Zvi 2018-11-22 15:36:14 +02:00
parent a46e500f9e
commit d41f85702d
1 changed files with 13 additions and 10 deletions

View File

@ -285,8 +285,6 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
prediction = prediction.cpu()
""" """
Removes detections with lower object confidence score than 'conf_thres' and performs Removes detections with lower object confidence score than 'conf_thres' and performs
Non-Maximum Suppression to further filter detections. Non-Maximum Suppression to further filter detections.
@ -302,15 +300,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# cross-class NMS # cross-class NMS
cross_class_nms = False cross_class_nms = False
if cross_class_nms: if cross_class_nms:
thresh = 0.85 # thresh = 0.85
thresh = nms_thres
a = pred.clone() a = pred.clone()
a = a[np.argsort(-a[:, 4])] # sort best to worst _, indices = torch.sort(-a[:, 4], 0) # sort best to worst
a = a[indices]
radius = 30 # area to search for cross-class ious radius = 30 # area to search for cross-class ious
for i in range(len(a)): for i in range(len(a)):
if i >= len(a) - 1: if i >= len(a) - 1:
break break
close = (np.abs(a[i, 0] - a[i + 1:, 0]) < radius) & (np.abs(a[i, 1] - a[i + 1:, 1]) < radius) close = (torch.abs(a[i, 0] - a[i + 1:, 0]) < radius) & (torch.abs(a[i, 1] - a[i + 1:, 1]) < radius)
close = close.nonzero() close = close.nonzero()
if len(close) > 0: if len(close) > 0:
@ -324,10 +324,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
a = a[mask] a = a[mask]
pred = a pred = a
x, y, w, h = pred[:, 0].numpy(), pred[:, 1].numpy(), pred[:, 2].numpy(), pred[:, 3].numpy() x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
a = w * h # area a = w * h # area
ar = w / (h + 1e-16) # aspect ratio ar = w / (h + 1e-16) # aspect ratio
log_w, log_h, log_a, log_ar = np.log(w), np.log(h), np.log(a), np.log(ar)
log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar)
# n = len(w) # n = len(w)
# shape_likelihood = np.zeros((n, 60), dtype=np.float32) # shape_likelihood = np.zeros((n, 60), dtype=np.float32)
@ -338,8 +339,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1) class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
v = ((pred[:, 4] > conf_thres) & (class_prob > .3)).numpy() v = ((pred[:, 4] > conf_thres) & (class_prob > .3))
v = v.nonzero() v = v.nonzero().squeeze()
if len(v.shape) == 0:
v = v.unsqueeze(0)
pred = pred[v] pred = pred[v]
class_prob = class_prob[v] class_prob = class_prob[v]
@ -363,7 +366,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# Iterate through all predicted classes # Iterate through all predicted classes
unique_labels = detections[:, -1].cpu().unique() unique_labels = detections[:, -1].cpu().unique()
if prediction.is_cuda: if prediction.is_cuda:
unique_labels = unique_labels.cuda() unique_labels = unique_labels.cuda(prediction.device)
nms_style = 'OR' # 'AND' or 'OR' (classical) nms_style = 'OR' # 'AND' or 'OR' (classical)
for c in unique_labels: for c in unique_labels: