From 6f0086103c39e728d7a1f8d94b3dbbcb530d5ad3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 10 Feb 2019 21:10:50 +0100 Subject: [PATCH] updates --- test.py | 27 +++++++++++++-------------- train.py | 3 ++- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/test.py b/test.py index e2709af7..7a3a8441 100644 --- a/test.py +++ b/test.py @@ -44,10 +44,8 @@ def test( outputs, mAPs, mR, mP, TP, confidence, pred_class, target_class = [], [], [], [], [], [], [], [] AP_accum, AP_accum_count = np.zeros(nC), np.zeros(nC) for batch_i, (imgs, targets) in enumerate(dataloader): - - with torch.no_grad(): - output = model(imgs.to(device)) - output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres) + output = model(imgs.to(device)) + output = non_max_suppression(output, conf_thres=conf_thres, nms_thres=nms_thres) # Compute average precision for each sample for sample_i, (labels, detections) in enumerate(zip(targets, output)): @@ -134,13 +132,14 @@ if __name__ == '__main__': opt = parser.parse_args() print(opt, end='\n\n') - mAP = test( - opt.cfg, - opt.data_cfg, - opt.weights, - opt.batch_size, - opt.img_size, - opt.iou_thres, - opt.conf_thres, - opt.nms_thres - ) + with torch.no_grad(): + mAP = test( + opt.cfg, + opt.data_cfg, + opt.weights, + opt.batch_size, + opt.img_size, + opt.iou_thres, + opt.conf_thres, + opt.nms_thres + ) diff --git a/train.py b/train.py index 147eb237..5793a322 100644 --- a/train.py +++ b/train.py @@ -195,7 +195,8 @@ def train( os.system('cp ' + latest + ' ' + os.path.join(weights, 'backup{}.pt'.format(epoch))) # Calculate mAP - mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size) + with torch.no_grad(): + mAP, R, P = test.test(cfg, data_cfg, weights=latest, batch_size=batch_size, img_size=img_size) # Write epoch results with open('results.txt', 'a') as file: