This commit is contained in:
Glenn Jocher 2019-02-10 21:10:50 +01:00
parent 51eb173416
commit 6f0086103c
2 changed files with 15 additions and 15 deletions

View File

@ -44,8 +44,6 @@ def test(
outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], [] outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], []
AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC) AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC)
for batch_i, (imgs, targets) in enumerate(dataloader): for batch_i, (imgs, targets) in enumerate(dataloader):
with torch.no_grad():
output = model(imgs.to(device)) output = model(imgs.to(device))
output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres) output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres)
@ -134,6 +132,7 @@ if __name__ == '__main__':
opt = parser.parse_args() opt = parser.parse_args()
print(opt, end='\n\n') print(opt, end='\n\n')
with torch.no_grad():
mAP = test( mAP = test(
opt.cfg, opt.cfg,
opt.data_cfg, opt.data_cfg,

View File

@ -195,6 +195,7 @@ def train(
os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch))) os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch)))
# Calculate mAP # Calculate mAP
with torch.no_grad():
mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size) mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size)
# Write epoch results # Write epoch results