diff --git a/utils/utils.py b/utils/utils.py index 164dfc9e..3992b085 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -369,12 +369,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): if prediction.is_cuda: unique_labels = unique_labels.cuda(prediction.device) - nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) + nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental) for c in unique_labels: # Get the detections with class c dc = detections[detections[:, -1] == c] # Sort the detections by maximum object confidence - _, conf_sort_index = torch.sort(dc[:, 4], descending=True) + _, conf_sort_index = torch.sort(dc[:, 4] * dc[:, 5], descending=True) dc = dc[conf_sort_index] # Non-maximum suppression @@ -411,6 +411,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): # 4964 5000 0.632 0.597 0.588 # normal # 4964 5000 0.632 0.597 0.588 # squared # 4964 5000 0.631 0.597 0.588 # sqrt + # normal best_v1_0.pt if len(det_max) > 0: det_max = torch.cat(det_max)