This commit is contained in:
Glenn Jocher 2019-04-02 22:54:32 +02:00
parent be10b75eb4
commit c36f1e990b
2 changed files with 5 additions and 5 deletions

View File

@ -14,7 +14,7 @@ def detect(
images, images,
output='output', # output folder output='output', # output folder
img_size=416, img_size=416,
conf_thres=0.3, conf_thres=0.5,
nms_thres=0.5, nms_thres=0.5,
save_txt=False, save_txt=False,
save_images=True, save_images=True,

View File

@ -358,7 +358,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# Filter out confidence scores below threshold # Filter out confidence scores below threshold
class_conf, class_pred = pred[:, 5:].max(1) class_conf, class_pred = pred[:, 5:].max(1)
# pred[:, 4] *= class_conf pred[:, 4] *= class_conf
i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh) i = (pred[:, 4] > conf_thres) & (pred[:, 2] > min_wh) & (pred[:, 3] > min_wh)
pred = pred[i] pred = pred[i]
@ -373,7 +373,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
# Box (center x, center y, width, height) to (x1, y1, x2, y2) # Box (center x, center y, width, height) to (x1, y1, x2, y2)
pred[:, :4] = xywh2xyxy(pred[:, :4]) pred[:, :4] = xywh2xyxy(pred[:, :4])
pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551 # pred[:, 4] *= class_conf # improves mAP from 0.549 to 0.551
# Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred) # Detections ordered as (x1y1x2y2, obj_conf, class_conf, class_pred)
pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1) pred = torch.cat((pred[:, :5], class_conf.unsqueeze(1), class_pred), 1)
@ -479,7 +479,7 @@ def plot_wh_methods(): # from utils.utils import *; plot_wh_methods()
fig.savefig('comparison.jpg', dpi=fig.dpi) fig.savefig('comparison.jpg', dpi=fig.dpi)
def plot_results(start=0): # from utils.utils import *; plot_results() def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
# Plot training results files 'results*.txt' # Plot training results files 'results*.txt'
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt') # import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
@ -487,7 +487,7 @@ def plot_results(start=0): # from utils.utils import *; plot_results()
s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP'] s = ['X + Y', 'Width + Height', 'Confidence', 'Classification', 'Total Loss', 'Precision', 'Recall', 'mAP']
for f in sorted(glob.glob('results*.txt')): for f in sorted(glob.glob('results*.txt')):
results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP results = np.loadtxt(f, usecols=[2, 3, 4, 5, 6, 9, 10, 11]).T # column 11 is mAP
x = range(start, results.shape[1]) x = range(start, stop if stop else results.shape[1])
for i in range(8): for i in range(8):
plt.subplot(2, 4, i + 1) plt.subplot(2, 4, i + 1)
plt.plot(x, results[i, x], marker='.', label=f) plt.plot(x, results[i, x], marker='.', label=f)