diff --git a/train.py b/train.py index 3c43f0d7..4535c0d5 100644 --- a/train.py +++ b/train.py @@ -446,13 +446,13 @@ if __name__ == '__main__': for _ in range(1): # generations to evolve if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate # Select parent(s) - x = np.loadtxt('evolve.txt', ndmin=2) parent = 'single' # parent selection method: 'single' or 'weighted' + x = np.loadtxt('evolve.txt', ndmin=2) + n = min(3, len(x)) # number of previous results to consider + x = x[np.argsort(-fitness(x))][:n] # top n mutations if parent == 'single' or len(x) == 1: - x = x[fitness(x).argmax()] + x = x[random.randint(0, n - 1)] # select one of the top n elif parent == 'weighted': # weighted combination - n = min(10, len(x)) # number to merge - x = x[np.argsort(-fitness(x))][:n] # top n mutations w = fitness(x) - fitness(x).min() # weights x = (x * w.reshape(n, 1)).sum(0) / w.sum() # new parent