This commit is contained in:
Glenn Jocher 2019-12-19 18:32:45 -08:00
parent fd949a8eb3
commit 674d0de170
1 changed files with 7 additions and 2 deletions

View File

@ -501,7 +501,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
pred[:, :4] = xywh2xyxy(pred[:, :4])
# Detections ordered as (x1y1x2y2, conf, cls)
# Expand
expand = False
if expand:
i, j = (pred[:, 4:] > conf_thres).nonzero().t()
pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1) # (x1y1x2y2, conf, cls)
else:
pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1)
# Get detections sorted by decreasing confidence scores