This commit is contained in:
Glenn Jocher 2019-04-22 16:21:21 +02:00
parent 23cd4ecfa7
commit ab8d8cbc93
1 changed files with 11 additions and 6 deletions

View File

@ -26,7 +26,11 @@ def detect(
os.makedirs(output) # make new output folder
# Initialize model
model = Darknet(cfg, img_size)
if ONNX_EXPORT:
s = (416, 416) # onnx model image size
model = Darknet(cfg, s)
else:
model = Darknet(cfg, img_size)
# Load weights
if weights.endswith('.pt'): # pytorch format
@ -37,8 +41,14 @@ def detect(
# Fuse Conv2d + BatchNorm2d layers
model.fuse()
# Eval mode
model.to(device).eval()
if ONNX_EXPORT:
img = torch.zeros((1, 3, s[0], s[1]))
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
return
# Set Dataloader
vid_path, vid_writer = None, None
if webcam:
@ -55,11 +65,6 @@ def detect(
t = time.time()
save_path = str(Path(output) / Path(path).name)
if ONNX_EXPORT:
img = torch.zeros((1, 3, 416, 416))
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
return
# Get detections
img = torch.from_numpy(img).unsqueeze(0).to(device)
pred, _ = model(img)