updates
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
2877ac9286
commit
bfe6d560c0
28
detect.py
28
detect.py
|
@ -2,7 +2,7 @@ import argparse
|
||||||
import time
|
import time
|
||||||
from sys import platform
|
from sys import platform
|
||||||
|
|
||||||
from models import *
|
from models import * # set ONNX_EXPORT in models.py
|
||||||
from utils.datasets import *
|
from utils.datasets import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@ def detect(save_txt=False, save_img=True, stream_img=False):
|
||||||
|
|
||||||
# Initialize
|
# Initialize
|
||||||
device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
|
device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
|
||||||
torch.backends.cudnn.benchmark = False # set False to speed up variable image size inference
|
|
||||||
if os.path.exists(out):
|
if os.path.exists(out):
|
||||||
shutil.rmtree(out) # delete output folder
|
shutil.rmtree(out) # delete output folder
|
||||||
os.makedirs(out) # make new output folder
|
os.makedirs(out) # make new output folder
|
||||||
|
@ -30,6 +29,7 @@ def detect(save_txt=False, save_img=True, stream_img=False):
|
||||||
|
|
||||||
# Fuse Conv2d + BatchNorm2d layers
|
# Fuse Conv2d + BatchNorm2d layers
|
||||||
# model.fuse()
|
# model.fuse()
|
||||||
|
# torch.backends.cudnn.benchmark = True # set True to speed up constant image size inference
|
||||||
|
|
||||||
# Eval mode
|
# Eval mode
|
||||||
model.to(device).eval()
|
model.to(device).eval()
|
||||||
|
@ -50,9 +50,9 @@ def detect(save_txt=False, save_img=True, stream_img=False):
|
||||||
if webcam:
|
if webcam:
|
||||||
save_img = False
|
save_img = False
|
||||||
stream_img = True
|
stream_img = True
|
||||||
dataloader = LoadWebcam(opt.source, img_size=img_size, half=opt.half)
|
dataset = LoadWebcam(opt.source, img_size=img_size, half=opt.half)
|
||||||
else:
|
else:
|
||||||
dataloader = LoadImages(opt.source, img_size=img_size, half=opt.half)
|
dataset = LoadImages(opt.source, img_size=img_size, half=opt.half)
|
||||||
|
|
||||||
# Get classes and colors
|
# Get classes and colors
|
||||||
classes = load_classes(parse_data_cfg(opt.data)['names'])
|
classes = load_classes(parse_data_cfg(opt.data)['names'])
|
||||||
|
@ -60,7 +60,7 @@ def detect(save_txt=False, save_img=True, stream_img=False):
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for path, img, im0, vid_cap in dataloader:
|
for path, img, im0, vid_cap in dataset:
|
||||||
t = time.time()
|
t = time.time()
|
||||||
save_path = str(Path(out) / Path(path).name)
|
save_path = str(Path(out) / Path(path).name)
|
||||||
|
|
||||||
|
@ -69,15 +69,15 @@ def detect(save_txt=False, save_img=True, stream_img=False):
|
||||||
pred, _ = model(img)
|
pred, _ = model(img)
|
||||||
det = non_max_suppression(pred.float(), opt.conf_thres, opt.nms_thres)[0]
|
det = non_max_suppression(pred.float(), opt.conf_thres, opt.nms_thres)[0]
|
||||||
|
|
||||||
|
s = '%gx%g ' % img.shape[2:] # string to print image size
|
||||||
if det is not None and len(det):
|
if det is not None and len(det):
|
||||||
# Rescale boxes from img_size to im0 size
|
# Rescale boxes from img_size to im0 size
|
||||||
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
|
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
|
||||||
|
|
||||||
# Print results to screen
|
# Print results
|
||||||
print('%gx%g ' % img.shape[2:], end='') # print image size
|
|
||||||
for c in det[:, -1].unique():
|
for c in det[:, -1].unique():
|
||||||
n = (det[:, -1] == c).sum()
|
n = (det[:, -1] == c).sum() # detections per class
|
||||||
print('%g %ss' % (n, classes[int(c)]), end=', ')
|
s += '%g %ss' % (n, classes[int(c)]) # add to string
|
||||||
|
|
||||||
# Write results
|
# Write results
|
||||||
for *xyxy, conf, _, cls in det:
|
for *xyxy, conf, _, cls in det:
|
||||||
|
@ -89,13 +89,15 @@ def detect(save_txt=False, save_img=True, stream_img=False):
|
||||||
label = '%s %.2f' % (classes[int(cls)], conf)
|
label = '%s %.2f' % (classes[int(cls)], conf)
|
||||||
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
|
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])
|
||||||
|
|
||||||
print('Done. (%.3fs)' % (time.time() - t))
|
print('%sDone. (%.3fs)' % (s, time.time() - t))
|
||||||
|
|
||||||
if stream_img: # Stream results
|
# Stream results
|
||||||
|
if stream_img:
|
||||||
cv2.imshow(opt.weights, im0)
|
cv2.imshow(opt.weights, im0)
|
||||||
|
|
||||||
if save_img: # Save image with detections
|
# Save results (image with detections)
|
||||||
if dataloader.mode == 'images':
|
if save_img:
|
||||||
|
if dataset.mode == 'images':
|
||||||
cv2.imwrite(save_path, im0)
|
cv2.imwrite(save_path, im0)
|
||||||
else:
|
else:
|
||||||
if vid_path != save_path: # new video
|
if vid_path != save_path: # new video
|
||||||
|
|
Loading…
Reference in New Issue