diff --git a/detect.py b/detect.py index 29151bbd..3a8c6f14 100644 --- a/detect.py +++ b/detect.py @@ -26,7 +26,11 @@ def detect( os.makedirs(output) # make new output folder # Initialize model - model = Darknet(cfg, img_size) + if ONNX_EXPORT: + s = (416, 416) # onnx model image size + model = Darknet(cfg, s) + else: + model = Darknet(cfg, img_size) # Load weights if weights.endswith('.pt'): # pytorch format @@ -37,8 +41,14 @@ def detect( # Fuse Conv2d + BatchNorm2d layers model.fuse() + # Eval mode model.to(device).eval() + if ONNX_EXPORT: + img = torch.zeros((1, 3, s[0], s[1])) + torch.onnx.export(model, img, 'weights/export.onnx', verbose=True) + return + # Set Dataloader vid_path, vid_writer = None, None if webcam: @@ -55,11 +65,6 @@ def detect( t = time.time() save_path = str(Path(output) / Path(path).name) - if ONNX_EXPORT: - img = torch.zeros((1, 3, 416, 416)) - torch.onnx.export(model, img, 'weights/export.onnx', verbose=True) - return - # Get detections img = torch.from_numpy(img).unsqueeze(0).to(device) pred, _ = model(img)