diff --git a/utils/utils.py b/utils/utils.py index 1348ad69..860b73ef 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -288,8 +288,6 @@ def build_targets(pred_boxes, pred_conf, pred_cls, target, anchor_wh, nA, nC, nG def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): - prediction = prediction.cpu() - """ Removes detections with lower object confidence score than 'conf_thres' and performs Non-Maximum Suppression to further filter detections. @@ -305,15 +303,17 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): # cross-class NMS cross_class_nms = False if cross_class_nms: - thresh = 0.85 + # thresh = 0.85 + thresh = nms_thres a = pred.clone() - a = a[np.argsort(-a[:, 4])] # sort best to worst + _, indices = torch.sort(-a[:, 4], 0) # sort best to worst + a = a[indices] radius = 30 # area to search for cross-class ious for i in range(len(a)): if i >= len(a) - 1: break - close = (np.abs(a[i, 0] - a[i + 1:, 0]) < radius) & (np.abs(a[i, 1] - a[i + 1:, 1]) < radius) + close = (torch.abs(a[i, 0] - a[i + 1:, 0]) < radius) & (torch.abs(a[i, 1] - a[i + 1:, 1]) < radius) close = close.nonzero() if len(close) > 0: @@ -327,10 +327,11 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): a = a[mask] pred = a - x, y, w, h = pred[:, 0].numpy(), pred[:, 1].numpy(), pred[:, 2].numpy(), pred[:, 3].numpy() + x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3] a = w * h # area ar = w / (h + 1e-16) # aspect ratio - log_w, log_h, log_a, log_ar = np.log(w), np.log(h), np.log(a), np.log(ar) + + log_w, log_h, log_a, log_ar = torch.log(w), torch.log(h), torch.log(a), torch.log(ar) # n = len(w) # shape_likelihood = np.zeros((n, 60), dtype=np.float32) @@ -341,8 +342,10 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1) - v = ((pred[:, 4] > conf_thres) & (class_prob > .3)).numpy() - v = v.nonzero() + v = ((pred[:, 4] > conf_thres) & (class_prob > .3)) + v = v.nonzero().squeeze() + if len(v.shape) == 0: + v = v.unsqueeze(0) pred = pred[v] class_prob = class_prob[v] @@ -366,7 +369,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): # Iterate through all predicted classes unique_labels = detections[:, -1].cpu().unique() if prediction.is_cuda: - unique_labels = unique_labels.cuda() + unique_labels = unique_labels.cuda(prediction.device) nms_style = 'OR' # 'AND' or 'OR' (classical) for c in unique_labels: