diff --git a/utils/utils.py b/utils/utils.py index 852f3d83..38ff9fd3 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -537,7 +537,8 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=Tru pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1) # Apply finite constraint - pred = pred[torch.isfinite(pred).all(1)] + if not torch.isfinite(pred).all(): + pred = pred[torch.isfinite(pred).all(1)] # Batched NMS if method == 'vision_batch':