This commit is contained in:
Glenn Jocher 2018-11-21 19:10:10 +01:00
parent b0b19b3b94
commit 7283f52d0c
1 changed files with 10 additions and 8 deletions

View File

@ -7,6 +7,7 @@ from utils.utils import *
cuda = torch.cuda.is_available() cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu') device = torch.device('cuda:0' if cuda else 'cpu')
f_path = os.path.dirname(os.path.realpath(__file__)) + '/'
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Get data configuration # Get data configuration
@ -16,8 +17,8 @@ parser.add_argument('-output_folder', type=str, default='output', help='path to
parser.add_argument('-plot_flag', type=bool, default=True) parser.add_argument('-plot_flag', type=bool, default=True)
parser.add_argument('-txt_out', type=bool, default=False) parser.add_argument('-txt_out', type=bool, default=False)
parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') parser.add_argument('-cfg', type=str, default=f_path + 'cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('-class_path', type=str, default='data/coco.names', help='path to class label file') parser.add_argument('-class_path', type=str, default=f_path + 'data/coco.names', help='path to class label file')
parser.add_argument('-conf_thres', type=float, default=0.50, help='object confidence threshold') parser.add_argument('-conf_thres', type=float, default=0.50, help='object confidence threshold')
parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression') parser.add_argument('-nms_thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
parser.add_argument('-batch_size', type=int, default=1, help='size of the batches') parser.add_argument('-batch_size', type=int, default=1, help='size of the batches')
@ -25,6 +26,7 @@ parser.add_argument('-img_size', type=int, default=32 * 13, help='size of each i
opt = parser.parse_args() opt = parser.parse_args()
print(opt) print(opt)
def main(opt): def main(opt):
os.system('rm -rf ' + opt.output_folder) os.system('rm -rf ' + opt.output_folder)
os.makedirs(opt.output_folder, exist_ok=True) os.makedirs(opt.output_folder, exist_ok=True)
@ -32,12 +34,12 @@ def main(opt):
# Load model # Load model
model = Darknet(opt.cfg, opt.img_size) model = Darknet(opt.cfg, opt.img_size)
weights_path = 'weights/yolov3.pt' weights_path = f_path + 'weights/yolov3.pt'
if weights_path.endswith('.weights'): # saved in darknet format if weights_path.endswith('.weights'): # saved in darknet format
load_weights(model, weights_path) load_weights(model, weights_path)
else: # endswith('.pt'), saved in pytorch format else: # endswith('.pt'), saved in pytorch format
if weights_path == 'weights/yolov3.pt' and not os.path.isfile(weights_path): if weights_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_path):
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -P weights') os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_path)
checkpoint = torch.load(weights_path, map_location='cpu') checkpoint = torch.load(weights_path, map_location='cpu')
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
@ -63,8 +65,8 @@ def main(opt):
imgs = [] # Stores image paths imgs = [] # Stores image paths
img_detections = [] # Stores detections for each image index img_detections = [] # Stores detections for each image index
prev_time = time.time() prev_time = time.time()
for batch_i, (img_paths, img) in enumerate(dataloader): for i, (img_paths, img) in enumerate(dataloader):
print(batch_i, img.shape, end=' ') print('%g/%g' % (i + 1, len(dataloader)), end=' ')
# Get detections # Get detections
with torch.no_grad(): with torch.no_grad():
@ -76,7 +78,7 @@ def main(opt):
img_detections.extend(detections) img_detections.extend(detections)
imgs.extend(img_paths) imgs.extend(img_paths)
print('Batch %d... (Done %.3f s)' % (batch_i, time.time() - prev_time)) print('Batch %d... Done. (%.3fs)' % (i, time.time() - prev_time))
prev_time = time.time() prev_time = time.time()
# Bounding-box colors # Bounding-box colors