updates
This commit is contained in:
parent
6e1ff541c9
commit
8dfa653942
|
@ -15,7 +15,9 @@ from utils.utils import xyxy2xywh
|
|||
class load_images(): # for inference
|
||||
def __init__(self, path, batch_size=1, img_size=416):
|
||||
if os.path.isdir(path):
|
||||
image_format = ['.jpg', '.jpeg', '.png', '.tif']
|
||||
self.files = sorted(glob.glob('%s/*.*' % path))
|
||||
self.files = list(filter(lambda x: os.path.splitext(x)[1].lower() in image_format, self.files))
|
||||
elif os.path.isfile(path):
|
||||
self.files = [path]
|
||||
|
||||
|
|
|
@ -42,8 +42,11 @@ def main():
|
|||
spec = coremltools.utils.convert_neural_network_spec_weights_to_fp16(spec)
|
||||
yolov3_model = coremltools.models.MLModel(spec)
|
||||
|
||||
name_out0 = spec.description.output[0].name
|
||||
name_out1 = spec.description.output[1].name
|
||||
|
||||
num_classes = 80
|
||||
num_anchors = 507
|
||||
num_anchors = 507 # 507 for yolov3-tiny,
|
||||
spec.description.output[0].type.multiArrayType.shape.append(num_anchors)
|
||||
spec.description.output[0].type.multiArrayType.shape.append(num_classes)
|
||||
# spec.description.output[0].type.multiArrayType.shape.append(1)
|
||||
|
@ -75,7 +78,7 @@ def main():
|
|||
from PIL import Image
|
||||
img = Image.open('../yolov3/data/samples/zidane_416.jpg')
|
||||
out = yolov3_model.predict({'0': img}, useCPUOnly=True)
|
||||
print(out['148'].shape, out['150'].shape)
|
||||
print(out[name_out0].shape, out[name_out1].shape)
|
||||
|
||||
# 3. Create NMS protobuf
|
||||
import numpy as np
|
||||
|
@ -107,15 +110,15 @@ def main():
|
|||
del ma_type.shape[:]
|
||||
|
||||
nms = nms_spec.nonMaximumSuppression
|
||||
nms.confidenceInputFeatureName = '148' # 1x507x80
|
||||
nms.coordinatesInputFeatureName = '150' # 1x507x4
|
||||
nms.confidenceInputFeatureName = name_out0 # 1x507x80
|
||||
nms.coordinatesInputFeatureName = name_out1 # 1x507x4
|
||||
nms.confidenceOutputFeatureName = 'confidence'
|
||||
nms.coordinatesOutputFeatureName = 'coordinates'
|
||||
nms.iouThresholdInputFeatureName = 'iouThreshold'
|
||||
nms.confidenceThresholdInputFeatureName = 'confidenceThreshold'
|
||||
|
||||
nms.iouThreshold = 0.6
|
||||
nms.confidenceThreshold = 0.3
|
||||
nms.iouThreshold = 0.4
|
||||
nms.confidenceThreshold = 0.5
|
||||
nms.pickTop.perClass = True
|
||||
|
||||
labels = np.loadtxt('../yolov3/data/coco.names', dtype=str, delimiter='\n')
|
||||
|
|
Loading…
Reference in New Issue