From 01dbdc45d7e7b8a66c251993c19513891461c439 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 14 Jan 2020 22:22:24 -0800 Subject: [PATCH] updates --- train.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index dbb389da..c754bca4 100644 --- a/train.py +++ b/train.py @@ -457,20 +457,21 @@ if __name__ == '__main__': x = (x * w.reshape(n, 1)).sum(0) / w.sum() # new parent # Mutate - mutate_version = 2 - np.random.seed(int(time.time())) + method = 2 s = 0.2 # 20% sigma + np.random.seed(int(time.time())) g = np.array([1, 1, 1, 1, 1, 1, 1, 0, .1, 1, 0, 1, 1, 1, 1, 1, 1, 1]) # gains ng = len(g) - if mutate_version == 1: - s *= np.random.random() # sigma - v = (np.random.randn(ng) * g * s + 1) ** 2.0 # plt.hist(x.ravel(), 300) - else: + if method == 1: + v = (np.random.randn(ng) * np.random.random() * g * s + 1) ** 2.0 + elif method == 2: + v = (np.random.randn(ng) * np.random.random(ng) * g * s + 1) ** 2.0 + elif method == 3: v = np.ones(ng) while all(v == 1): # mutate untill a change occurs (prevent duplicates) r = (np.random.random(ng) < 0.1) * np.random.randn(ng) # 10% mutation probability - v = (g * s * r + 1) ** 2.0 # plt.hist(x.ravel(), 300) - for i, k in enumerate(hyp.keys()): + v = (g * s * r + 1) ** 2.0 + for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300) hyp[k] = x[i + 7] * v[i] # mutate # Clip to limits