This commit is contained in:
Glenn Jocher 2019-02-18 19:44:15 +01:00
parent f788a57009
commit 2ef92f5651
1 changed files with 3 additions and 2 deletions

View File

@ -369,12 +369,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
if prediction.is_cuda: if prediction.is_cuda:
unique_labels = unique_labels.cuda(prediction.device) unique_labels = unique_labels.cuda(prediction.device)
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental) nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
for c in unique_labels: for c in unique_labels:
# Get the detections with class c # Get the detections with class c
dc = detections[detections[:, -1] == c] dc = detections[detections[:, -1] == c]
# Sort the detections by maximum object confidence # Sort the detections by maximum object confidence
_, conf_sort_index = torch.sort(dc[:, 4], descending=True) _, conf_sort_index = torch.sort(dc[:, 4] * dc[:, 5], descending=True)
dc = dc[conf_sort_index] dc = dc[conf_sort_index]
# Non-maximum suppression # Non-maximum suppression
@ -411,6 +411,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# 4964 5000 0.632 0.597 0.588 # normal # 4964 5000 0.632 0.597 0.588 # normal
# 4964 5000 0.632 0.597 0.588 # squared # 4964 5000 0.632 0.597 0.588 # squared
# 4964 5000 0.631 0.597 0.588 # sqrt # 4964 5000 0.631 0.597 0.588 # sqrt
# normal best_v1_0.pt
if len(det_max) > 0: if len(det_max) > 0:
det_max = torch.cat(det_max) det_max = torch.cat(det_max)