diff --git a/detect.py b/detect.py index 7be56115..f2501970 100644 --- a/detect.py +++ b/detect.py @@ -34,12 +34,12 @@ def detect(save_img=False): modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']) # load weights modelc.to(device).eval() - # Fuse Conv2d + BatchNorm2d layers - # model.fuse() - # Eval mode model.to(device).eval() + # Fuse Conv2d + BatchNorm2d layers + # model.fuse() + # Export mode if ONNX_EXPORT: model.fuse()