This commit is contained in:
Glenn Jocher 2019-04-22 16:21:21 +02:00
parent 23cd4ecfa7
commit ab8d8cbc93
1 changed files with 11 additions and 6 deletions

View File

@ -26,6 +26,10 @@ def detect(
os.makedirs(output) # make new output folder os.makedirs(output) # make new output folder
# Initialize model # Initialize model
if ONNX_EXPORT:
s = (416, 416) # onnx model image size
model = Darknet(cfg, s)
else:
model = Darknet(cfg, img_size) model = Darknet(cfg, img_size)
# Load weights # Load weights
@ -37,8 +41,14 @@ def detect(
# Fuse Conv2d + BatchNorm2d layers # Fuse Conv2d + BatchNorm2d layers
model.fuse() model.fuse()
# Eval mode
model.to(device).eval() model.to(device).eval()
if ONNX_EXPORT:
img = torch.zeros((1, 3, s[0], s[1]))
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
return
# Set Dataloader # Set Dataloader
vid_path, vid_writer = None, None vid_path, vid_writer = None, None
if webcam: if webcam:
@ -55,11 +65,6 @@ def detect(
t = time.time() t = time.time()
save_path = str(Path(output) / Path(path).name) save_path = str(Path(output) / Path(path).name)
if ONNX_EXPORT:
img = torch.zeros((1, 3, 416, 416))
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
return
# Get detections # Get detections
img = torch.from_numpy(img).unsqueeze(0).to(device) img = torch.from_numpy(img).unsqueeze(0).to(device)
pred, _ = model(img) pred, _ = model(img)