diff --git a/utils/utils.py b/utils/utils.py index d2b07461..54f248ff 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -352,11 +352,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): # shape_likelihood[:, c] = # multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2]) - # Filter out confidence scores below threshold + # Multiply conf by class conf to get combined confidence class_conf, class_pred = pred[:, 5:].max(1) pred[:, 4] *= class_conf - i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh) + # Select only suitable predictions + i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (torch.isnan(pred).any(1) == 0) pred = pred[i] # If none are remaining => process next image @@ -532,7 +533,7 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results() x = range(start, min(stop, n) if stop else n) for i in range(10): plt.subplot(2, 5, i + 1) - plt.plot(x, results[i, x].clip(max=None), marker='.', label=f.replace('.txt', '')) + plt.plot(x, results[i, x], marker='.', label=f.replace('.txt', '')) plt.title(s[i]) if i == 0: plt.legend()