updates
This commit is contained in:
parent
23cd4ecfa7
commit
ab8d8cbc93
15
detect.py
15
detect.py
|
@ -26,6 +26,10 @@ def detect(
|
|||
os.makedirs(output) # make new output folder
|
||||
|
||||
# Initialize model
|
||||
if ONNX_EXPORT:
|
||||
s = (416, 416) # onnx model image size
|
||||
model = Darknet(cfg, s)
|
||||
else:
|
||||
model = Darknet(cfg, img_size)
|
||||
|
||||
# Load weights
|
||||
|
@ -37,8 +41,14 @@ def detect(
|
|||
# Fuse Conv2d + BatchNorm2d layers
|
||||
model.fuse()
|
||||
|
||||
# Eval mode
|
||||
model.to(device).eval()
|
||||
|
||||
if ONNX_EXPORT:
|
||||
img = torch.zeros((1, 3, s[0], s[1]))
|
||||
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
|
||||
return
|
||||
|
||||
# Set Dataloader
|
||||
vid_path, vid_writer = None, None
|
||||
if webcam:
|
||||
|
@ -55,11 +65,6 @@ def detect(
|
|||
t = time.time()
|
||||
save_path = str(Path(output) / Path(path).name)
|
||||
|
||||
if ONNX_EXPORT:
|
||||
img = torch.zeros((1, 3, 416, 416))
|
||||
torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
|
||||
return
|
||||
|
||||
# Get detections
|
||||
img = torch.from_numpy(img).unsqueeze(0).to(device)
|
||||
pred, _ = model(img)
|
||||
|
|
Loading…
Reference in New Issue