This commit is contained in:
Glenn Jocher 2019-02-18 19:52:38 +01:00
parent 0f06fbd681
commit bbb750876e
1 changed files with 3 additions and 3 deletions

View File

@ -345,7 +345,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
class_prob, class_pred = torch.max(F.softmax(pred[:, 5:], 1), 1)
v = (pred[:, 4] > (conf_thres * class_prob)) # TODO examine arbitrary 0.3 thres here
v = ((pred[:, 4] > conf_thres) & (class_prob > .1)) # TODO examine arbitrary 0.3 thres here
v = v.nonzero().squeeze()
if len(v.shape) == 0:
v = v.unsqueeze(0)
@ -369,7 +369,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
if prediction.is_cuda:
unique_labels = unique_labels.cuda(prediction.device)
nms_style = 'OR' # 'OR' (default), 'AND', 'MERGE' (experimental)
nms_style = 'MERGE' # 'OR' (default), 'AND', 'MERGE' (experimental)
for c in unique_labels:
# Get the detections with class c
dc = detections[detections[:, -1] == c]
@ -389,7 +389,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# Image Total P R mAP
# 5000 5000 0.627 0.593 0.584
# 4964 5000 0.629 0.594 0.586 # complete probability sort
# 4964 5000 0.629 0.594 0.586 # complete probability sort
elif nms_style == 'AND': # requires overlap, single boxes erased