car-detection-bayes/our_scripts/config_bayes.py

70 lines
2.8 KiB
Python

import yaml
class Configuration:
    """Attribute-style view of a YAML experiment configuration.

    The loaded document must provide the top-level sections ``train``
    (containing a nested ``other-hyps`` mapping), ``experiments``,
    ``detect``, ``confussion-matrix`` and ``bayes``.  Each section is wrapped
    in a nested class whose instance ``__dict__`` mirrors the section's keys.
    Keys are copied verbatim, so a key that is not a valid Python identifier
    (for example one containing ``-``) is only reachable through
    ``getattr`` / ``__dict__``.
    """

    class Train:
        """Entries of the ``train`` section, minus ``other-hyps``."""

        class OtherHyps:
            """Entries of the ``train`` -> ``other-hyps`` sub-section."""

            def __init__(self, config_file) -> None:
                """Copy every ``other-hyps`` entry onto the instance.

                Args:
                    config_file: Mapping of the whole parsed configuration.
                """
                self.__dict__.update(config_file['train']['other-hyps'].items())

        def __init__(self, config_file) -> None:
            """Copy the ``train`` entries and wrap ``other-hyps`` separately.

            Args:
                config_file: Mapping of the whole parsed configuration.
            """
            # Assigned first so that it stays the first entry of ``__dict__``.
            self.other_hyps = Configuration.Train.OtherHyps(config_file)
            for key, value in config_file['train'].items():
                # The raw sub-section is exposed through ``other_hyps`` only.
                if key != 'other-hyps':
                    self.__dict__[key] = value

    class Experiments:
        """Entries of the ``experiments`` section."""

        def __init__(self, config_file) -> None:
            """Copy every ``experiments`` entry onto the instance."""
            self.__dict__.update(config_file['experiments'].items())

    class Detect:
        """Entries of the ``detect`` section."""

        def __init__(self, config_file) -> None:
            """Copy every ``detect`` entry onto the instance."""
            self.__dict__.update(config_file['detect'].items())

    class ConfussionMatrix:
        """Entries of the ``confussion-matrix`` section."""

        def __init__(self, config_file) -> None:
            """Copy every ``confussion-matrix`` entry onto the instance."""
            self.__dict__.update(config_file['confussion-matrix'].items())

    class Bayes:
        """Entries of the ``bayes`` section."""

        def __init__(self, config_file) -> None:
            """Copy every ``bayes`` entry onto the instance."""
            self.__dict__.update(config_file['bayes'].items())

    def __init__(self, config_path) -> None:
        """Parse the YAML file at *config_path* and build all sections.

        Args:
            config_path: Path of the YAML configuration file.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If one of the required sections is missing.
        """
        self.config_path = config_path
        # NOTE(security): yaml.Loader can construct arbitrary Python objects.
        # Only load trusted files, or switch to yaml.safe_load if the
        # configuration never relies on Python-specific tags.
        with open(config_path, 'r') as stream:  # close the handle promptly
            config_file = yaml.load(stream, Loader=yaml.Loader)
        self.train = self.Train(config_file)
        self.experiments = self.Experiments(config_file)
        self.detect = self.Detect(config_file)
        self.confussion_matrix = self.ConfussionMatrix(config_file)
        self.bayes = self.Bayes(config_file)

    def get_bayes_bounds(self) -> list:
        """Describe the search domain of every tunable setting.

        The entries of ``train``, ``train`` -> ``other-hyps`` and ``detect``
        are merged (on a name clash the later section wins).  Every entry
        whose value is a dict is read as a parameter specification with a
        ``type`` key:

        * ``continuous``: domain is ``(min, max)``.
        * ``discrete`` with ``step``: domain is
          ``tuple(range(min, max, step))``.
        * ``discrete`` without ``step``: domain is ``tuple(values)``.

        Non-dict entries are ignored.

        Returns:
            A list of ``{'name': ..., 'type': ..., 'domain': ...}`` dicts in
            merge order.

        Raises:
            ValueError: If a specification has an unknown ``type``.
            KeyError: If a specification lacks a key its type requires.
        """
        result = []
        dicts = {
            **self.train.__dict__,
            **self.train.other_hyps.__dict__,
            **self.detect.__dict__,
        }
        for key, value in dicts.items():
            # Plain settings (numbers, strings, the OtherHyps object, ...) are
            # fixed values, not search dimensions.
            if not isinstance(value, dict):
                continue
            kind = value['type']
            if kind == 'continuous':
                domain = (value['min'], value['max'])
            elif kind == 'discrete' and 'step' in value:
                # NOTE(review): range() stops before ``max``; confirm that the
                # half-open interval is what the configuration intends.
                domain = tuple(range(value['min'], value['max'], value['step']))
            elif kind == 'discrete':
                # Explicit list of allowed values.
                domain = tuple(value['values'])
            else:
                raise ValueError("Invalid type", kind)
            result.append({'name': key, 'type': kind, 'domain': domain})
        return result
if __name__ == '__main__':
    # Manual smoke test: load the YAML file named on the command line.
    # Configuration requires a path, so it is taken from argv rather than
    # calling the constructor without arguments (which raises TypeError).
    import sys

    if len(sys.argv) != 2:
        sys.exit("usage: {} CONFIG_YAML".format(sys.argv[0]))
    config = Configuration(sys.argv[1])
    print(config)