From 2ef92f565158ec92c3d134239a0d3030c23b99a0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 18 Feb 2019 19:44:15 +0100 Subject: [PATCH] updates --- utils/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 164dfc9e..3992b085 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -369,12 +369,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): if prediction.is_cuda: unique_labels = unique_labels.cuda(prediction.device) - nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) + nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental) for c in unique_labels: # Get the detections with class c dc = detections[detections[:, -1] == c] # Sort the detections by maximum object confidence - _, conf_sort_index = torch.sort(dc[:, 4], descending=True) + _, conf_sort_index = torch.sort(dc[:, 4] * dc[:, 5], descending=True) dc = dc[conf_sort_index] # Non-maximum suppression @@ -411,6 +411,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): # 4964 5000 0.632 0.597 0.588 # normal # 4964 5000 0.632 0.597 0.588 # squared # 4964 5000 0.631 0.597 0.588 # sqrt + # normal best_v1_0.pt if len(det_max) > 0: det_max = torch.cat(det_max)