diff --git a/detect.py b/detect.py index b9eb3d85..b94903c1 100644 --- a/detect.py +++ b/detect.py @@ -75,7 +75,8 @@ def detect(save_img=False): # Run inference t0 = time.time() - _ = model(torch.zeros((1, 3, img_size, img_size), device=device)) if device.type != 'cpu' else None # run once + img = torch.zeros((1, 3, img_size, img_size), device=device) # init img + _ = model(img.half() if half else img.float()) if device.type != 'cpu' else None # run once for path, img, im0s, vid_cap in dataset: img = torch.from_numpy(img).to(device) img = img.half() if half else img.float() # uint8 to fp16/32