Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2019-09-02 16:22:13 +02:00
parent 1e4351c4a2
commit 109173d555
1 changed files with 10 additions and 10 deletions

View File

@ -9,8 +9,8 @@ from utils.utils import *
def detect(save_txt=False, save_img=True, stream_img=False): def detect(save_txt=False, save_img=True, stream_img=False):
img_size = (320, 192) if ONNX_EXPORT else opt.img_size # (320, 192) or (416, 256) or (608, 352) for (height, width) img_size = (320, 192) if ONNX_EXPORT else opt.img_size # (320, 192) or (416, 256) or (608, 352) for (height, width)
webcam = opt.source == '0' or opt.source.startswith('rtsp') or opt.source.startswith('http') out, source, weights, half = opt.output, opt.source, opt.weights, opt.half
out = opt.output webcam = source == '0' or source.startswith('rtsp') or source.startswith('http')
# Initialize # Initialize
device = torch_utils.select_device(force_cpu=ONNX_EXPORT) device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
@ -22,10 +22,10 @@ def detect(save_txt=False, save_img=True, stream_img=False):
model = Darknet(opt.cfg, img_size) model = Darknet(opt.cfg, img_size)
# Load weights # Load weights
if opt.weights.endswith('.pt'): # pytorch format if weights.endswith('.pt'): # pytorch format
model.load_state_dict(torch.load(opt.weights, map_location=device)['model']) model.load_state_dict(torch.load(weights, map_location=device)['model'])
else: # darknet format else: # darknet format
_ = load_darknet_weights(model, opt.weights) _ = load_darknet_weights(model, weights)
# Fuse Conv2d + BatchNorm2d layers # Fuse Conv2d + BatchNorm2d layers
# model.fuse() # model.fuse()
@ -41,8 +41,8 @@ def detect(save_txt=False, save_img=True, stream_img=False):
return return
# Half precision # Half precision
opt.half = opt.half and device.type != 'cpu' # half precision only supported on CUDA half = half and device.type != 'cpu' # half precision only supported on CUDA
if opt.half: if half:
model.half() model.half()
# Set Dataloader # Set Dataloader
@ -50,9 +50,9 @@ def detect(save_txt=False, save_img=True, stream_img=False):
if webcam: if webcam:
save_img = False save_img = False
stream_img = True stream_img = True
dataset = LoadWebcam(opt.source, img_size=img_size, half=opt.half) dataset = LoadWebcam(source, img_size=img_size, half=half)
else: else:
dataset = LoadImages(opt.source, img_size=img_size, half=opt.half) dataset = LoadImages(source, img_size=img_size, half=half)
# Get classes and colors # Get classes and colors
classes = load_classes(parse_data_cfg(opt.data)['names']) classes = load_classes(parse_data_cfg(opt.data)['names'])
@ -93,7 +93,7 @@ def detect(save_txt=False, save_img=True, stream_img=False):
# Stream results # Stream results
if stream_img: if stream_img:
cv2.imshow(opt.weights, im0) cv2.imshow(weights, im0)
# Save results (image with detections) # Save results (image with detections)
if save_img: if save_img: