This commit is contained in:
Glenn Jocher 2019-01-06 15:58:41 +02:00
parent 6e1ff541c9
commit 8dfa653942
2 changed files with 11 additions and 6 deletions

View File

@ -15,7 +15,9 @@ from utils.utils import xyxy2xywh
class load_images(): # for inference
def __init__(self, path, batch_size=1, img_size=416):
if os.path.isdir(path):
image_format = ['.jpg', '.jpeg', '.png', '.tif']
self.files = sorted(glob.glob('%s/*.*' % path))
self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in image_format, self.files))
elif os.path.isfile(path):
self.files = [path]

View File

@ -42,8 +42,11 @@ def main():
spec = coremltools.utils.convert_neural_network_spec_weights_to_fp16(spec)
yolov3_model = coremltools.models.MLModel(spec)
name_out0 = spec.description.output[0].name
name_out1 = spec.description.output[1].name
num_classes = 80
num_anchors = 507
num_anchors = 507 # 507 for yolov3-tiny,
spec.description.output[0].type.multiArrayType.shape.append(num_anchors)
spec.description.output[0].type.multiArrayType.shape.append(num_classes)
# spec.description.output[0].type.multiArrayType.shape.append(1)
@ -75,7 +78,7 @@ def main():
from PIL import Image
img = Image.open('../yolov3/data/samples/zidane_416.jpg')
out = yolov3_model.predict({'0': img}, useCPUOnly=True)
print(out['148'].shape, out['150'].shape)
print(out[name_out0].shape, out[name_out1].shape)
# 3. Create NMS protobuf
import numpy as np
@ -107,15 +110,15 @@ def main():
del ma_type.shape[:]
nms = nms_spec.nonMaximumSuppression
nms.confidenceInputFeatureName = '148' # 1x507x80
nms.coordinatesInputFeatureName = '150' # 1x507x4
nms.confidenceInputFeatureName = name_out0 # 1x507x80
nms.coordinatesInputFeatureName = name_out1 # 1x507x4
nms.confidenceOutputFeatureName = 'confidence'
nms.coordinatesOutputFeatureName = 'coordinates'
nms.iouThresholdInputFeatureName = 'iouThreshold'
nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
nms.iouThreshold = 0.6
nms.confidenceThreshold = 0.3
nms.iouThreshold = 0.4
nms.confidenceThreshold = 0.5
nms.pickTop.perClass = True
labels = np.loadtxt('../yolov3/data/coco.names', dtype=str, delimiter='\n')