This commit is contained in:
Glenn Jocher 2019-12-20 09:07:25 -08:00
parent 8d54770859
commit 2bc6683325
2 changed files with 40 additions and 24 deletions

19
test.py
View File

@ -78,6 +78,9 @@ def test(cfg,
if batch_i == 0 and not os.path.exists('test_batch0.jpg'): if batch_i == 0 and not os.path.exists('test_batch0.jpg'):
plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg') plot_images(imgs=imgs, targets=targets, paths=paths, fname='test_batch0.jpg')
# Disable gradients
with torch.no_grad():
# Run model # Run model
inf_out, train_out = model(imgs) # inference and training outputs inf_out, train_out = model(imgs) # inference and training outputs
@ -220,7 +223,7 @@ if __name__ == '__main__':
opt = parser.parse_args() opt = parser.parse_args()
print(opt) print(opt)
with torch.no_grad(): # Test
test(opt.cfg, test(opt.cfg,
opt.data, opt.data,
opt.weights, opt.weights,
@ -229,3 +232,17 @@ if __name__ == '__main__':
opt.conf_thres, opt.conf_thres,
opt.nms_thres, opt.nms_thres,
opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']])) opt.save_json or any([x in opt.data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]))
# # Parameter study
# y = []
# x = np.arange(0.4, 0.81, 0.1)
# for v in x:
# y.append(test(opt.cfg, opt.data, opt.weights, opt.batch_size, opt.img_size, 0.1, v, True)[0])
# y = np.stack(y, 0)
#
# fig, ax = plt.subplots(1, 1, figsize=(12, 6))
# ax.plot(x, y[:, 2], marker='.', label='mAP@0.5')
# ax.plot(x, y[:, 3], marker='.', label='mAP@0.5:0.95')
# ax.legend()
# fig.tight_layout()
# plt.savefig('parameters.jpg', dpi=200)

View File

@ -323,7 +323,6 @@ def train():
if opt.prebias: if opt.prebias:
print_model_biases(model) print_model_biases(model)
elif not opt.notest or final_epoch: # Calculate mAP elif not opt.notest or final_epoch: # Calculate mAP
with torch.no_grad():
is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80 is_coco = any([x in data for x in ['coco.data', 'coco2014.data', 'coco2017.data']]) and model.nc == 80
results, maps = test.test(cfg, results, maps = test.test(cfg,
data, data,