diff --git a/utils/utils.py b/utils/utils.py index 658da502..856893ac 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -461,6 +461,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): class_conf, class_pred = pred[:, 5:].max(1) pred[:, 4] *= class_conf + # # Merge classes (optional) + # class_pred[(class_pred.view(-1,1) == torch.LongTensor([2, 3, 5, 6, 7]).view(1,-1)).any(1)] = 2 + # + # # Remove classes (optional) + # pred[class_pred != 2, 4] = 0.0 + # Select only suitable predictions i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & torch.isfinite(pred).all(1) pred = pred[i]