updates
This commit is contained in:
parent
be10b75eb4
commit
c36f1e990b
|
@ -14,7 +14,7 @@ def detect(
|
||||||
images,
|
images,
|
||||||
output='output', # output folder
|
output='output', # output folder
|
||||||
img_size=416,
|
img_size=416,
|
||||||
conf_thres=0.3,
|
conf_thres=0.5,
|
||||||
nms_thres=0.5,
|
nms_thres=0.5,
|
||||||
save_txt=False,
|
save_txt=False,
|
||||||
save_images=True,
|
save_images=True,
|
||||||
|
|
|
@ -358,7 +358,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||||
|
|
||||||
# Filter out confidence scores below threshold
|
# Filter out confidence scores below threshold
|
||||||
class_conf, class_pred = pred[:, 5:].max(1)
|
class_conf, class_pred = pred[:, 5:].max(1)
|
||||||
# pred[:, 4] *= class_conf
|
pred[:, 4] *= class_conf
|
||||||
|
|
||||||
i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh)
|
i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh)
|
||||||
pred = pred[i]
|
pred = pred[i]
|
||||||
|
@ -373,7 +373,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
|
||||||
|
|
||||||
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
|
||||||
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
pred[:, :4] = xywh2xyxy(pred[:, :4])
|
||||||
pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551
|
# pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551
|
||||||
|
|
||||||
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
|
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
|
||||||
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
|
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
|
||||||
|
@ -479,7 +479,7 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
|
||||||
fig.savefig('comparison.jpg', dpi=fig.dpi)
|
fig.savefig('comparison.jpg', dpi=fig.dpi)
|
||||||
|
|
||||||
|
|
||||||
def plot_results(start=0): # from utils.utils import *; plot_results()
|
def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
|
||||||
# Plot training results files 'results*.txt'
|
# Plot training results files 'results*.txt'
|
||||||
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
|
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
|
||||||
|
|
||||||
|
@ -487,7 +487,7 @@ def plot_results(start=0): # from utils.utils import *; plot_results()
|
||||||
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
|
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
|
||||||
for f in sorted(glob.glob('results*.txt')):
|
for f in sorted(glob.glob('results*.txt')):
|
||||||
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP
|
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP
|
||||||
x = range(start, results.shape[1])
|
x = range(start, stop if stop else results.shape[1])
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
plt.subplot(2, 4, i + 1)
|
plt.subplot(2, 4, i + 1)
|
||||||
plt.plot(x, results[i, x], marker='.', label=f)
|
plt.plot(x, results[i, x], marker='.', label=f)
|
||||||
|
|
Loading…
Reference in New Issue