merge_batch NMS method

This commit is contained in:
Glenn Jocher 2020-03-26 11:28:46 -07:00
parent c71ab7d506
commit f91b1fb13a
1 changed files with 5 additions and 3 deletions

View File

@ -17,6 +17,7 @@ def test(cfg,
iou_thres=0.6, # for nms iou_thres=0.6, # for nms
save_json=False, save_json=False,
single_cls=False, single_cls=False,
augment=False,
model=None, model=None,
dataloader=None): dataloader=None):
# Initialize/load model and set device # Initialize/load model and set device
@ -87,7 +88,7 @@ def test(cfg,
# Disable gradients # Disable gradients
with torch.no_grad(): with torch.no_grad():
if opt.augment: # augmented testing https://github.com/ultralytics/yolov3/issues/931 if augment: # augmented testing https://github.com/ultralytics/yolov3/issues/931
imgs = torch.cat((imgs, imgs = torch.cat((imgs,
imgs.flip(3), # flip-lr imgs.flip(3), # flip-lr
torch_utils.scale_img(imgs, 0.7), # scale torch_utils.scale_img(imgs, 0.7), # scale
@ -98,7 +99,7 @@ def test(cfg,
inf_out, train_out = model(imgs) # inference and training outputs inf_out, train_out = model(imgs) # inference and training outputs
t0 += torch_utils.time_synchronized() - t t0 += torch_utils.time_synchronized() - t
if opt.augment: if augment:
x = torch.split(inf_out, nb, dim=0) x = torch.split(inf_out, nb, dim=0)
x[1][..., 0] = width - x[1][..., 0] # flip lr x[1][..., 0] = width - x[1][..., 0] # flip lr
x[2][..., :4] /= 0.7 # scale x[2][..., :4] /= 0.7 # scale
@ -261,7 +262,8 @@ if __name__ == '__main__':
opt.conf_thres, opt.conf_thres,
opt.iou_thres, opt.iou_thres,
opt.save_json, opt.save_json,
opt.single_cls) opt.single_cls,
opt.augment)
elif opt.task == 'benchmark': # mAPs at 320-608 at conf 0.5 and 0.7 elif opt.task == 'benchmark': # mAPs at 320-608 at conf 0.5 and 0.7
y = [] y = []