diff --git a/utils/utils.py b/utils/utils.py index b9180e13..3b72bbbc 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -491,16 +491,15 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru output = [None] * len(prediction) for image_i, pred in enumerate(prediction): - # Remove rows - pred = pred[pred[:, 4] > conf_thres] # retain above threshold + # Retain > conf + pred = pred[pred[:, 4] > conf_thres] # compute conf torch.sigmoid_(pred[..., 5:]) pred[..., 5:] *= pred[..., 4:5] # conf = obj_conf * cls_conf # Apply width-height constraint - i = (pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1) & torch.isfinite(pred).all(1) - pred = pred[i] + pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)] # If none are remaining => process next image if len(pred) == 0: @@ -517,6 +516,9 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru conf, j = pred[:, 5:].max(1) pred = torch.cat((pred[:, :4], conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # (xyxy, conf, cls) + # Apply finite constraint + pred = pred[torch.isfinite(pred).all(1)] + # Get detections sorted by decreasing confidence scores pred = pred[pred[:, 4].argsort(descending=True)]