Glenn Jocher 2019-09-10 01:34:23 +02:00
parent 4445715f4c
commit d1b6929043
2 changed files with 114 additions and 44 deletions

View File

@@ -11,6 +11,7 @@ def detect(save_txt=False, save_img=False, stream_img=False):
     img_size = (320, 192) if ONNX_EXPORT else opt.img_size  # (320, 192) or (416, 256) or (608, 352) for (height, width)
     out, source, weights, half = opt.output, opt.source, opt.weights, opt.half
     webcam = source == '0' or source.startswith('rtsp') or source.startswith('http')
+    streams = source == 'streams.txt'
 
     # Initialize
     device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
@@ -47,7 +48,9 @@ def detect(save_txt=False, save_img=False, stream_img=False):
 
     # Set Dataloader
     vid_path, vid_writer = None, None
-    if webcam:
+    if streams:
+        dataset = LoadStreams(source, img_size=img_size, half=half)
+    elif webcam:
         stream_img = True
         dataset = LoadWebcam(source, img_size=img_size, half=half)
     else:
@@ -60,16 +63,23 @@ def detect(save_txt=False, save_img=False, stream_img=False):
 
     # Run inference
     t0 = time.time()
-    for path, img, im0, vid_cap in dataset:
+    for path, img, im0s, vid_cap in dataset:
         t = time.time()
-        save_path = str(Path(out) / Path(path).name)
 
         # Get detections
-        img = torch.from_numpy(img).unsqueeze(0).to(device)
+        img = torch.from_numpy(img).to(device)
+        if img.ndimension() == 3:
+            img = img.unsqueeze(0)
         pred, _ = model(img)
-        det = non_max_suppression(pred.float(), opt.conf_thres, opt.nms_thres)[0]
 
-        s = '%gx%g ' % img.shape[2:]  # print string
+        for i, det in enumerate(non_max_suppression(pred, opt.conf_thres, opt.nms_thres)):  # detections per image
+            if streams:  # batch_size > 1
+                p, s, im0 = path[i], '%g: ' % i, im0s[i]
+            else:
+                p, s, im0 = path, '', im0s
+
+            save_path = str(Path(out) / Path(p).name)
+            s += '%gx%g ' % img.shape[2:]  # print string
             if det is not None and len(det):
                 # Rescale boxes from img_size to im0 size
                 det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
@@ -93,7 +103,7 @@ def detect(save_txt=False, save_img=False, stream_img=False):
 
             # Stream results
             if stream_img:
-                cv2.imshow(weights, im0)
+                cv2.imshow(p, im0)
 
             # Save results (image with detections)
             if save_img:
@@ -106,9 +116,9 @@ def detect(save_txt=False, save_img=False, stream_img=False):
                            vid_writer.release()  # release previous video writer
 
                         fps = vid_cap.get(cv2.CAP_PROP_FPS)
-                        width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                        height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (width, height))
+                        w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*opt.fourcc), fps, (w, h))
                     vid_writer.write(im0)
 
     if save_txt or save_img:
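
Context for the change above: the new LoadStreams path is taken only when the source string is exactly 'streams.txt', a plain text file listing one stream address per line ('0' maps to the local webcam). The file contents and command below are an illustrative sketch, not part of this commit; the addresses are placeholders and the --source flag name is assumed from opt.source.

# streams.txt — one source per line (example addresses only)
rtsp://192.168.1.64:554/stream1
http://example.com/live/stream.m3u8
0

# run detection over all listed streams in one batch (assumed CLI flag name)
python detect.py --source streams.txt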

View File

@@ -3,7 +3,9 @@ import math
 import os
 import random
 import shutil
+import time
 from pathlib import Path
+from threading import Thread
 
 import cv2
 import numpy as np
@@ -183,6 +185,64 @@ class LoadWebcam:  # for inference
         return 0
 
 
+class LoadStreams:  # multiple IP or RTSP cameras
+    def __init__(self, path='streams.txt', img_size=416, half=False):
+        self.img_size = img_size
+        self.half = half  # half precision fp16 images
+
+        with open(path, 'r') as f:
+            sources = f.read().splitlines()
+        n = len(sources)
+        self.imgs = [None] * n
+        self.sources = sources
+        for i, s in enumerate(sources):
+            # Start the thread to read frames from the video stream
+            cap = cv2.VideoCapture(0 if s == '0' else s)
+            fps = cap.get(cv2.CAP_PROP_FPS) % 100
+            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            print('%g/%g: %gx%g at %.2f FPS %s...' % (i + 1, n, width, height, fps, s))
+            thread = Thread(target=self.update, args=([i, cap]))
+            thread.daemon = True
+            thread.start()
+        print('')  # newline
+        time.sleep(0.5)
+
+    def update(self, index, cap):
+        # Read next stream frame in a daemon thread
+        while cap.isOpened():
+            _, self.imgs[index] = cap.read()
+            time.sleep(0.030)  # 33.3 FPS to keep buffer empty
+
+    def __iter__(self):
+        self.count = -1
+        return self
+
+    def __next__(self):
+        self.count += 1
+        img0 = self.imgs.copy()
+        if cv2.waitKey(1) == ord('q'):  # 'q' to quit
+            cv2.destroyAllWindows()
+            raise StopIteration
+
+        # Letterbox
+        img = [letterbox(x, new_shape=self.img_size, mode='square')[0] for x in img0]
+
+        # Stack
+        img = np.stack(img, 0)
+
+        # Normalize RGB
+        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB
+        img = np.ascontiguousarray(img, dtype=np.float16 if self.half else np.float32)  # uint8 to fp16/fp32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
+
+        return self.sources, img, img0, None
+
+    def __len__(self):
+        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
+
+
 class LoadImagesAndLabels(Dataset):  # for training/testing
     def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=True, image_weights=False,
                  cache_images=False):
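
A rough standalone usage sketch for the new loader (not part of the commit): each iteration yields the list of source strings, a stacked (n, 3, h, w) float batch ready for the model, the raw BGR frames, and None in place of a cv2.VideoCapture handle. The utils.datasets import path and the streams.txt file name are assumptions for illustration.

import cv2
from utils.datasets import LoadStreams  # assumed location of the class above

dataset = LoadStreams('streams.txt', img_size=416, half=False)
for sources, img, img0, _ in dataset:
    # img:  np.ndarray of shape (n, 3, h, w), values in 0-1, ready for torch.from_numpy()
    # img0: list of n original BGR frames, one per stream
    for src, frame in zip(sources, img0):
        cv2.imshow(src, frame)  # one window per stream; 'q' is handled inside __next__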