Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2019-09-02 16:04:54 +02:00
parent 2877ac9286
commit bfe6d560c0
1 changed files with 15 additions and 13 deletions

View File

@ -2,7 +2,7 @@ import argparse
import time import time
from sys import platform from sys import platform
from models import * from models import * # set ONNX_EXPORT in models.py
from utils.datasets import * from utils.datasets import *
from utils.utils import * from utils.utils import *
@ -14,7 +14,6 @@ def detect(save_txt=False, save_img=True, stream_img=False):
# Initialize # Initialize
device = torch_utils.select_device(force_cpu=ONNX_EXPORT) device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
torch.backends.cudnn.benchmark = False # set False to speed up variable image size inference
if os.path.exists(out): if os.path.exists(out):
shutil.rmtree(out) # delete output folder shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder os.makedirs(out) # make new output folder
@ -30,6 +29,7 @@ def detect(save_txt=False, save_img=True, stream_img=False):
# Fuse Conv2d + BatchNorm2d layers # Fuse Conv2d + BatchNorm2d layers
# model.fuse() # model.fuse()
# torch.backends.cudnn.benchmark = True # set True to speed up constant image size inference
# Eval mode # Eval mode
model.to(device).eval() model.to(device).eval()
@ -50,9 +50,9 @@ def detect(save_txt=False, save_img=True, stream_img=False):
if webcam: if webcam:
save_img = False save_img = False
stream_img = True stream_img = True
dataloader = LoadWebcam(opt.source, img_size=img_size, half=opt.half) dataset = LoadWebcam(opt.source, img_size=img_size, half=opt.half)
else: else:
dataloader = LoadImages(opt.source, img_size=img_size, half=opt.half) dataset = LoadImages(opt.source, img_size=img_size, half=opt.half)
# Get classes and colors # Get classes and colors
classes = load_classes(parse_data_cfg(opt.data)['names']) classes = load_classes(parse_data_cfg(opt.data)['names'])
@ -60,7 +60,7 @@ def detect(save_txt=False, save_img=True, stream_img=False):
# Run inference # Run inference
t0 = time.time() t0 = time.time()
for path, img, im0, vid_cap in dataloader: for path, img, im0, vid_cap in dataset:
t = time.time() t = time.time()
save_path = str(Path(out) / Path(path).name) save_path = str(Path(out) / Path(path).name)
@ -69,15 +69,15 @@ def detect(save_txt=False, save_img=True, stream_img=False):
pred, _ = model(img) pred, _ = model(img)
det = non_max_suppression(pred.float(), opt.conf_thres, opt.nms_thres)[0] det = non_max_suppression(pred.float(), opt.conf_thres, opt.nms_thres)[0]
s = '%gx%g ' % img.shape[2:] # string to print image size
if det is not None and len(det): if det is not None and len(det):
# Rescale boxes from img_size to im0 size # Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round() det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
# Print results to screen # Print results
print('%gx%g ' % img.shape[2:], end='') # print image size
for c in det[:, -1].unique(): for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() n = (det[:, -1] == c).sum() # detections per class
print('%g %ss' % (n, classes[int(c)]), end=', ') s += '%g %ss' % (n, classes[int(c)]) # add to string
# Write results # Write results
for *xyxy, conf, _, cls in det: for *xyxy, conf, _, cls in det:
@ -89,13 +89,15 @@ def detect(save_txt=False, save_img=True, stream_img=False):
label = '%s %.2f' % (classes[int(cls)], conf) label = '%s %.2f' % (classes[int(cls)], conf)
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)]) plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
print('Done. (%.3fs)' % (time.time() - t)) print('%sDone. (%.3fs)' % (s, time.time() - t))
if stream_img: # Stream results # Stream results
if stream_img:
cv2.imshow(opt.weights, im0) cv2.imshow(opt.weights, im0)
if save_img: # Save image with detections # Save results (image with detections)
if dataloader.mode == 'images': if save_img:
if dataset.mode == 'images':
cv2.imwrite(save_path, im0) cv2.imwrite(save_path, im0)
else: else:
if vid_path != save_path: # new video if vid_path != save_path: # new video