diff --git a/utils/utils.py b/utils/utils.py index 2e54870e..54adc05e 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -501,8 +501,13 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): # Box (center x, center y, width, height) to (x1, y1, x2, y2) pred[:, :4] = xywh2xyxy(pred[:, :4]) - # Detections ordered as (x1y1x2y2, conf, cls) - pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) + # Expand + expand = False + if expand: + i, j = (pred[:, 4:] > conf_thres).nonzero().t() + pred = torch.cat((pred[i, :4], pred[i, j].unsqueeze(1), j.float().unsqueeze(1)), 1) # (x1y1x2y2, conf, cls) + else: + pred = torch.cat((pred[:, :4], conf[i].unsqueeze(1), cls[i].unsqueeze(1).float()), 1) # Get detections sorted by decreasing confidence scores pred = pred[(-pred[:, 4]).argsort()]