This commit is contained in:
Glenn Jocher 2019-07-30 12:39:17 +02:00
parent 272b9c7c11
commit 01aaf2c11c
2 changed files with 3 additions and 1 deletions

View File

@ -40,11 +40,12 @@ def detect(cfg,
_ = load_darknet_weights(model, weights) _ = load_darknet_weights(model, weights)
# Fuse Conv2d + BatchNorm2d layers # Fuse Conv2d + BatchNorm2d layers
model.fuse() # model.fuse()
# Eval mode # Eval mode
model.to(device).eval() model.to(device).eval()
# Export mode
if ONNX_EXPORT: if ONNX_EXPORT:
img = torch.zeros((1, 3, s[0], s[1])) img = torch.zeros((1, 3, s[0], s[1]))
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True) torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)

View File

@ -2,6 +2,7 @@
# conda install numpy opencv matplotlib tqdm pillow # conda install numpy opencv matplotlib tqdm pillow
# conda install pytorch torchvision -c pytorch # conda install pytorch torchvision -c pytorch
# conda install scikit-image -c conda-forge # conda install scikit-image -c conda-forge
# conda install -c spyder-ide spyder-line-profiler
numpy numpy
opencv-python opencv-python
torch >= 1.1.0 torch >= 1.1.0