This commit is contained in:
Glenn Jocher 2019-08-31 19:11:59 +02:00
parent e926afd02b
commit 0a725a4bad
1 changed files with 5 additions and 10 deletions

View File

@ -7,24 +7,19 @@ from utils.datasets import *
from utils.utils import * from utils.utils import *
def detect(save_txt=False, def detect(save_txt=False, save_images=True):
save_images=True): img_size = (320, 192) if ONNX_EXPORT else opt.img_size # (320, 192) or (416, 256) or (608, 352) for (height, width)
out = opt.output out = opt.output
img_size = opt.img_size
# Initialize # Initialize
device = torch_utils.select_device(force_cpu=ONNX_EXPORT) device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
torch.backends.cudnn.benchmark = False # set False for reproducible results torch.backends.cudnn.benchmark = False # set False to speed up variable image size inference
if os.path.exists(out): if os.path.exists(out):
shutil.rmtree(out) # delete output folder shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder os.makedirs(out) # make new output folder
# Initialize model # Initialize model
if ONNX_EXPORT: model = Darknet(opt.cfg, img_size)
s = (320, 192) # (320, 192) or (416, 256) or (608, 352) onnx model image size (height, width)
model = Darknet(opt.cfg, s)
else:
model = Darknet(opt.cfg, img_size)
# Load weights # Load weights
if opt.weights.endswith('.pt'): # pytorch format if opt.weights.endswith('.pt'): # pytorch format
@ -40,7 +35,7 @@ def detect(save_txt=False,
# Export mode # Export mode
if ONNX_EXPORT: if ONNX_EXPORT:
img = torch.zeros((1, 3, s[0], s[1])) img = torch.zeros((1, 3) + img_size) # (1, 3, 320, 192)
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True) torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
return return