This commit is contained in:
Glenn Jocher 2019-04-13 20:11:08 +02:00
parent f299d83f40
commit 95f3d8e043
1 changed files with 4 additions and 3 deletions

View File

@ -352,11 +352,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
# shape_likelihood[:, c] = # shape_likelihood[:, c] =
# multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2]) # multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])
# Filter out confidence scores below threshold # Multiply conf by class conf to get combined confidence
class_conf, class_pred = pred[:, 5:].max(1) class_conf, class_pred = pred[:, 5:].max(1)
pred[:, 4] *= class_conf pred[:, 4] *= class_conf
i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh) # Select only suitable predictions
i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (torch.isnan(pred).any(1) == 0)
pred = pred[i] pred = pred[i]
# If none are remaining => process next image # If none are remaining => process next image
@ -532,7 +533,7 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
x = range(start, min(stop, n) if stop else n) x = range(start, min(stop, n) if stop else n)
for i in range(10): for i in range(10):
plt.subplot(2, 5, i + 1) plt.subplot(2, 5, i + 1)
plt.plot(x, results[i, x].clip(max=None), marker='.', label=f.replace('.txt', '')) plt.plot(x, results[i, x], marker='.', label=f.replace('.txt', ''))
plt.title(s[i]) plt.title(s[i])
if i == 0: if i == 0:
plt.legend() plt.legend()