From f91b1fb13a0f264effdc5a17dd0e3957924f905a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 26 Mar 2020 11:28:46 -0700 Subject: [PATCH] merge_batch NMS method --- test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test.py b/test.py index 939dd059..a1e36b88 100644 --- a/test.py +++ b/test.py @@ -17,6 +17,7 @@ def test(cfg, iou_thres=0.6, # for nms save_json=False, single_cls=False, + augment=False, model=None, dataloader=None): # Initialize/load model and set device @@ -87,7 +88,7 @@ def test(cfg, # Disable gradients with torch.no_grad(): - if opt.augment: # augmented testing https://github.com/ultralytics/yolov3/issues/931 + if augment: # augmented testing https://github.com/ultralytics/yolov3/issues/931 imgs = torch.cat((imgs, imgs.flip(3), # flip-lr torch_utils.scale_img(imgs, 0.7), # scale @@ -98,7 +99,7 @@ def test(cfg, inf_out, train_out = model(imgs) # inference and training outputs t0 += torch_utils.time_synchronized() - t - if opt.augment: + if augment: x = torch.split(inf_out, nb, dim=0) x[1][..., 0] = width - x[1][..., 0] # flip lr x[2][..., :4] /= 0.7 # scale @@ -261,7 +262,8 @@ if __name__ == '__main__': opt.conf_thres, opt.iou_thres, opt.save_json, - opt.single_cls) + opt.single_cls, + opt.augment) elif opt.task == 'benchmark': # mAPs at 320-608 at conf 0.5 and 0.7 y = []