updates
This commit is contained in:
parent
f299d83f40
commit
95f3d8e043
|
@ -352,11 +352,12 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5):
|
|||
# shape_likelihood[:, c] =
|
||||
# multivariate_normal.pdf(x, mean=mat['class_mu'][c, :2], cov=mat['class_cov'][c, :2, :2])
|
||||
|
||||
# Filter out confidence scores below threshold
|
||||
# Multiply conf by class conf to get combined confidence
|
||||
class_conf, class_pred = pred[:, 5:].max(1)
|
||||
pred[:, 4] *= class_conf
|
||||
|
||||
i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh)
|
||||
# Select only suitable predictions
|
||||
i = (pred[:, 4] > conf_thres) & (pred[:, 2:4] > min_wh).all(1) & (torch.isnan(pred).any(1) == 0)
|
||||
pred = pred[i]
|
||||
|
||||
# If none are remaining => process next image
|
||||
|
@ -532,7 +533,7 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
|
|||
x = range(start, min(stop, n) if stop else n)
|
||||
for i in range(10):
|
||||
plt.subplot(2, 5, i + 1)
|
||||
plt.plot(x, results[i, x].clip(max=None), marker='.', label=f.replace('.txt', ''))
|
||||
plt.plot(x, results[i, x], marker='.', label=f.replace('.txt', ''))
|
||||
plt.title(s[i])
|
||||
if i == 0:
|
||||
plt.legend()
|
||||
|
|
Loading…
Reference in New Issue