updates
This commit is contained in:
parent
b0b19b3b94
commit
7283f52d0c
18
detect.py
18
detect.py
|
@ -7,6 +7,7 @@ from utils.utils import *
|
||||||
|
|
||||||
cuda = torch.cuda.is_available()
|
cuda = torch.cuda.is_available()
|
||||||
device = torch.device('cuda:0' if cuda else 'cpu')
|
device = torch.device('cuda:0' if cuda else 'cpu')
|
||||||
|
f_path = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
# Get data configuration
|
# Get data configuration
|
||||||
|
@ -16,8 +17,8 @@ parser.add_argument('-output_folder', type=str, default='output', help='path to
|
||||||
parser.add_argument('-plot_flag', type=bool, default=True)
|
parser.add_argument('-plot_flag', type=bool, default=True)
|
||||||
parser.add_argument('-txt_out', type=bool, default=False)
|
parser.add_argument('-txt_out', type=bool, default=False)
|
||||||
|
|
||||||
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
|
parser.add_argument('-cfg', type=str, default=f_path + 'cfg/yolov3.cfg', help='cfg file path')
|
||||||
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file')
|
parser.add_argument('-class_path', type=str, default=f_path + 'data/coco.names', help='path to class label file')
|
||||||
parser.add_argument('-conf_thres', type=float, default=0.50, help='object confidence threshold')
|
parser.add_argument('-conf_thres', type=float, default=0.50, help='object confidence threshold')
|
||||||
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
|
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
|
||||||
parser.add_argument('-batch_size', type=int, default=1, help='size of the batches')
|
parser.add_argument('-batch_size', type=int, default=1, help='size of the batches')
|
||||||
|
@ -25,6 +26,7 @@ parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each i
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
||||||
|
|
||||||
def main(opt):
|
def main(opt):
|
||||||
os.system('rm -rf ' + opt.output_folder)
|
os.system('rm -rf ' + opt.output_folder)
|
||||||
os.makedirs(opt.output_folder, exist_ok=True)
|
os.makedirs(opt.output_folder, exist_ok=True)
|
||||||
|
@ -32,12 +34,12 @@ def main(opt):
|
||||||
# Load model
|
# Load model
|
||||||
model = Darknet(opt.cfg, opt.img_size)
|
model = Darknet(opt.cfg, opt.img_size)
|
||||||
|
|
||||||
weights_path = 'weights/yolov3.pt'
|
weights_path = f_path + 'weights/yolov3.pt'
|
||||||
if weights_path.endswith('.weights'): # saved in darknet format
|
if weights_path.endswith('.weights'): # saved in darknet format
|
||||||
load_weights(model, weights_path)
|
load_weights(model, weights_path)
|
||||||
else: # endswith('.pt'), saved in pytorch format
|
else: # endswith('.pt'), saved in pytorch format
|
||||||
if weights_path == 'weights/yolov3.pt' and not os.path.isfile(weights_path):
|
if weights_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_path):
|
||||||
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -P weights')
|
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_path)
|
||||||
|
|
||||||
checkpoint = torch.load(weights_path, map_location='cpu')
|
checkpoint = torch.load(weights_path, map_location='cpu')
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint['model'])
|
||||||
|
@ -63,8 +65,8 @@ def main(opt):
|
||||||
imgs = [] # Stores image paths
|
imgs = [] # Stores image paths
|
||||||
img_detections = [] # Stores detections for each image index
|
img_detections = [] # Stores detections for each image index
|
||||||
prev_time = time.time()
|
prev_time = time.time()
|
||||||
for batch_i, (img_paths, img) in enumerate(dataloader):
|
for i, (img_paths, img) in enumerate(dataloader):
|
||||||
print(batch_i, img.shape, end=' ')
|
print('%g/%g' % (i + 1, len(dataloader)), end=' ')
|
||||||
|
|
||||||
# Get detections
|
# Get detections
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -76,7 +78,7 @@ def main(opt):
|
||||||
img_detections.extend(detections)
|
img_detections.extend(detections)
|
||||||
imgs.extend(img_paths)
|
imgs.extend(img_paths)
|
||||||
|
|
||||||
print('Batch %d... (Done %.3f s)' % (batch_i, time.time() - prev_time))
|
print('Batch %d... Done. (%.3fs)' % (i, time.time() - prev_time))
|
||||||
prev_time = time.time()
|
prev_time = time.time()
|
||||||
|
|
||||||
# Bounding-box colors
|
# Bounding-box colors
|
||||||
|
|
Loading…
Reference in New Issue