diff --git a/test.py b/test.py index 939dd059..a1e36b88 100644 --- a/test.py +++ b/test.py @@ -17,6 +17,7 @@ def test(cfg, iou_thres=0.6, # for nms save_json=False, single_cls=False, + augment=False, model=None, dataloader=None): # Initialize/load model and set device @@ -87,7 +88,7 @@ def test(cfg, # Disable gradients with torch.no_grad(): - if opt.augment: # augmented testing https://github.com/ultralytics/yolov3/issues/931 + if augment: # augmented testing https://github.com/ultralytics/yolov3/issues/931 imgs = torch.cat((imgs, imgs.flip(3), # flip-lr torch_utils.scale_img(imgs, 0.7), # scale @@ -98,7 +99,7 @@ def test(cfg, inf_out, train_out = model(imgs) # inference and training outputs t0 += torch_utils.time_synchronized() - t - if opt.augment: + if augment: x = torch.split(inf_out, nb, dim=0) x[1][..., 0] = width - x[1][..., 0] # flip lr x[2][..., :4] /= 0.7 # scale @@ -261,7 +262,8 @@ if __name__ == '__main__': opt.conf_thres, opt.iou_thres, opt.save_json, - opt.single_cls) + opt.single_cls, + opt.augment) elif opt.task == 'benchmark': # mAPs at 320-608 at conf 0.5 and 0.7 y = []