updates
This commit is contained in:
parent
f788a57009
commit
2ef92f5651
|
@ -369,12 +369,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
if prediction.is_cuda:
|
||||
unique_labels = unique_labels.cuda(prediction.device)
|
||||
|
||||
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||
for c in unique_labels:
|
||||
# Get the detections with class c
|
||||
dc = detections[detections[:, -1] == c]
|
||||
# Sort the detections by maximum object confidence
|
||||
_, conf_sort_index = torch.sort(dc[:, 4], descending=True)
|
||||
_, conf_sort_index = torch.sort(dc[:, 4] * dc[:, 5], descending=True)
|
||||
dc = dc[conf_sort_index]
|
||||
|
||||
# Non-maximum suppression
|
||||
|
@ -411,6 +411,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
|||
# 4964 5000 0.632 0.597 0.588 # normal
|
||||
# 4964 5000 0.632 0.597 0.588 # squared
|
||||
# 4964 5000 0.631 0.597 0.588 # sqrt
|
||||
# normal best_v1_0.pt
|
||||
|
||||
if len(det_max) > 0:
|
||||
det_max = torch.cat(det_max)
|
||||
|
|
Loading…
Reference in New Issue