[conf > conf_thres] update

This commit is contained in:
Glenn Jocher 2020-05-17 20:59:19 -07:00
parent 5b572681ff
commit 0c7d7427e4
1 changed files with 2 additions and 2 deletions

View File

@ -482,7 +482,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
for xi, x in enumerate(prediction): # image index, image inference for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints # Apply constraints
x = x[x[:, 4] > conf_thres] # confidence x = x[x[:, 4] > conf_thres] # confidence
# x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height x = x[((x[:, 2:4] > min_wh) & (x[:, 2:4] < max_wh)).all(1)] # width-height
# If none remain process next image # If none remain process next image
if not x.shape[0]: if not x.shape[0]:
@ -500,7 +500,7 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, multi_label=T
x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) x = torch.cat((box[i], x[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
else: # best class only else: # best class only
conf, j = x[:, 5:].max(1) conf, j = x[:, 5:].max(1)
x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) x = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)[conf > conf_thres]
# Filter by class # Filter by class
if classes: if classes: