2018-08-26 08:51:39 +00:00
import argparse
import time
2019-02-12 15:58:07 +00:00
from sys import platform
2018-08-26 08:51:39 +00:00
from models import *
from utils . datasets import *
from utils . utils import *
2019-01-08 18:37:23 +00:00
2019-02-10 20:06:22 +00:00
def detect(
        cfg,
        data_cfg,
        weights,
        images='data/samples',  # input folder
        output='output',  # output folder
        fourcc='mp4v',
        img_size=416,
        conf_thres=0.5,
        nms_thres=0.5,
        save_txt=False,
        save_images=True,
        webcam=False
):
    """Run Darknet/YOLO inference on images, videos or a webcam stream.

    Args:
        cfg: path to the Darknet model .cfg file.
        data_cfg: path to the .data file; its 'names' entry is passed to
            load_classes to obtain the class names.
        weights: path to the weights. Files ending in '.pt' are loaded as a
            PyTorch checkpoint (its 'model' entry); anything else is loaded
            with load_darknet_weights.
        images: input path handed to LoadImages (unused when webcam=True).
        output: output folder. WARNING: if it exists it is deleted with
            shutil.rmtree and then re-created on every call.
        fourcc: four-character codec code used for cv2.VideoWriter output.
        img_size: inference size passed to Darknet and to the dataloader.
        conf_thres: confidence threshold passed to non_max_suppression.
        nms_thres: IoU threshold passed to non_max_suppression.
        save_txt: if True, append one line per detection
            ("x1 y1 x2 y2 cls conf") to '<save_path>.txt'.
        save_images: if True, write annotated images / videos into `output`.
            Forced to False when webcam=True.
        webcam: if True, read frames from LoadWebcam and show them with
            cv2.imshow instead of reading from `images`.

    Returns:
        None. When ONNX_EXPORT is set, the model is exported to
        'weights/export.onnx' and the function returns before inference.
    """
    import subprocess  # local import: only needed for the macOS viewer below

    device = torch_utils.select_device()
    if os.path.exists(output):
        shutil.rmtree(output)  # delete output folder
    os.makedirs(output)  # make new output folder

    # Initialize model
    if ONNX_EXPORT:
        s = (320, 192)  # onnx model image size (height, width)
        model = Darknet(cfg, s)
    else:
        model = Darknet(cfg, img_size)

    # Load weights
    if weights.endswith('.pt'):  # pytorch format
        model.load_state_dict(torch.load(weights, map_location=device)['model'])
    else:  # darknet format
        _ = load_darknet_weights(model, weights)

    # Fuse Conv2d + BatchNorm2d layers
    model.fuse()

    # Eval mode
    model.to(device).eval()
    if ONNX_EXPORT:
        # The dummy input must be on the same device as the model; otherwise
        # tracing fails with a device mismatch whenever a GPU was selected.
        img = torch.zeros((1, 3, s[0], s[1])).to(device)
        torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
        return

    # Set Dataloader
    vid_path, vid_writer = None, None
    save_path = None  # stays None if the dataloader yields no frames
    if webcam:
        save_images = False
        dataloader = LoadWebcam(img_size=img_size)
    else:
        dataloader = LoadImages(images, img_size=img_size)

    # Get classes and colors (one random BGR-style triple per class)
    classes = load_classes(parse_data_cfg(data_cfg)['names'])
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]

    for path, img, im0, vid_cap in dataloader:
        t = time.time()
        save_path = str(Path(output) / Path(path).name)

        # Get detections
        img = torch.from_numpy(img).unsqueeze(0).to(device)
        pred, _ = model(img)
        det = non_max_suppression(pred, conf_thres, nms_thres)[0]

        if det is not None and len(det) > 0:
            # Rescale boxes from the inference size to the original image size
            det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

            # Print results to screen
            print('%gx%g ' % img.shape[2:], end='')  # print image size
            for c in det[:, -1].unique():
                n = (det[:, -1] == c).sum()
                print('%g %ss' % (n, classes[int(c)]), end=', ')

            # Draw bounding boxes and labels of detections.
            # Row layout: box coordinates (first 4 columns), conf, cls_conf, cls.
            for *xyxy, conf, cls_conf, cls in det:
                if save_txt:  # Write to file
                    with open(save_path + '.txt', 'a') as file:
                        file.write(('%g ' * 6 + '\n') % (*xyxy, cls, conf))

                # Add bbox to the image
                label = '%s %.2f' % (classes[int(cls)], conf)
                plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])

        print('Done. (%.3fs)' % (time.time() - t))

        if webcam:  # Show live webcam
            cv2.imshow(weights, im0)

        if save_images:  # Save image with detections
            if dataloader.mode == 'images':
                cv2.imwrite(save_path, im0)
            else:
                if vid_path != save_path:  # new video
                    vid_path = save_path
                    if isinstance(vid_writer, cv2.VideoWriter):
                        vid_writer.release()  # release previous video writer

                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
                    width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (width, height))
                vid_writer.write(im0)

    # Release the last video writer too; earlier writers are released when a
    # new video starts, but without this the final video is never finalized.
    if isinstance(vid_writer, cv2.VideoWriter):
        vid_writer.release()

    if save_images:
        print('Results saved to %s' % (os.getcwd() + os.sep + output))
        if platform == 'darwin' and save_path is not None:  # macos
            # Pass an argument list (no shell) so paths with spaces or shell
            # metacharacters reach `open` intact.
            subprocess.run(['open', output, save_path])
2018-11-21 18:24:00 +00:00
2018-08-26 08:51:39 +00:00
if __name__ == '__main__':
    # Command-line entry point: parse options, then run detect() under
    # torch.no_grad() since only inference is performed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
    parser.add_argument('--data-cfg', type=str, default='data/coco.data', help='coco.data file path')
    parser.add_argument('--weights', type=str, default='weights/yolov3-spp.weights', help='path to weights file')
    parser.add_argument('--images', type=str, default='data/samples', help='path to images')
    parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.5, help='object confidence threshold')
    parser.add_argument('--nms-thres', type=float, default=0.5, help='iou threshold for non-maximum suppression')
    parser.add_argument('--fourcc', type=str, default='mp4v', help='specifies the fourcc code for output video encoding (make sure ffmpeg supports specified fourcc codec)')
    parser.add_argument('--output', type=str, default='output', help='specifies the output path for images and videos')
    # detect() already supports these; expose them so they can be enabled
    # from the command line. Both default to False (previous behavior).
    parser.add_argument('--save-txt', action='store_true', help='append detections to a .txt file next to each saved output')
    parser.add_argument('--webcam', action='store_true', help='read frames from a webcam instead of --images')
    opt = parser.parse_args()
    print(opt)

    with torch.no_grad():
        detect(
            opt.cfg,
            opt.data_cfg,
            opt.weights,
            images=opt.images,
            img_size=opt.img_size,
            conf_thres=opt.conf_thres,
            nms_thres=opt.nms_thres,
            fourcc=opt.fourcc,
            output=opt.output,
            save_txt=opt.save_txt,
            webcam=opt.webcam,
        )