This commit is contained in:
Glenn Jocher 2019-12-27 09:41:19 -08:00
parent b58f41ef53
commit 1c07b1906c
1 changed files with 1 additions and 1 deletions

View File

@ -529,7 +529,7 @@ def non_max_suppression(prediction, conf_thres=0.5, iou_thres=0.5, multi_cls=Tru
box = xywh2xyxy(pred[:, :4])
# Detections matrix nx6 (xyxy, conf, cls)
if multi_cls or conf_thres < 0.01:
if multi_cls:
i, j = (pred[:, 5:] > conf_thres).nonzero().t()
pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
else: # best class only