diff --git a/utils/utils.py b/utils/utils.py index 3b72bbbc..e8c83613 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -506,15 +506,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru continue # Box (center x, center y, width, height) to (x1, y1, x2, y2) - pred[:, :4] = xywh2xyxy(pred[:, :4]) + box = xywh2xyxy(pred[:, :4]) - # Multi-class + # Detections matrix nx6 (xyxy, conf, cls) if multi_cls or conf_thres < 0.01: i, j = (pred[:, 5:] > conf_thres).nonzero().t() - pred = torch.cat((pred[i, :4], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) + pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1) else: # best class only conf, j = pred[:, 5:].max(1) - pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls) + pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # Apply finite constraint pred = pred[torch.isfinite(pred).all(1)]