updates
This commit is contained in:
parent
036e3b3253
commit
41d55d452b
|
@ -374,7 +374,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||||
if prediction.is_cuda:
|
if prediction.is_cuda:
|
||||||
unique_labels = unique_labels.cuda(prediction.device)
|
unique_labels = unique_labels.cuda(prediction.device)
|
||||||
|
|
||||||
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
|
||||||
for c in unique_labels:
|
for c in unique_labels:
|
||||||
# Get the detections with class c
|
# Get the detections with class c
|
||||||
dc = detections[detections[:, -1] == c]
|
dc = detections[detections[:, -1] == c]
|
||||||
|
@ -384,7 +384,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||||
|
|
||||||
# Non-maximum suppression
|
# Non-maximum suppression
|
||||||
det_max = []
|
det_max = []
|
||||||
if nms_style == 'MERGE': # default
|
if nms_style == 'OR': # default
|
||||||
while dc.shape[0]:
|
while dc.shape[0]:
|
||||||
det_max.append(dc[:1]) # save highest conf detection
|
det_max.append(dc[:1]) # save highest conf detection
|
||||||
if len(dc) == 1: # Stop if we're at the last detection
|
if len(dc) == 1: # Stop if we're at the last detection
|
||||||
|
|
Loading…
Reference in New Issue