diff --git a/train.py b/train.py index bd056686..ec861237 100644 --- a/train.py +++ b/train.py @@ -429,15 +429,14 @@ if __name__ == '__main__': if os.path.exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate # Select parent(s) x = np.loadtxt('evolve.txt', ndmin=2) - if len(x) > 1: - parent = 'weighted' # parent selection method: 'single' or 'weighted' - if parent == 'single': - x = x[fitness(x).argmax()] - elif parent == 'weighted': # weighted combination - n = min(10, x.shape[0]) # number to merge - x = x[np.argsort(-fitness(x))][:n] # top n mutations - w = fitness(x) - fitness(x).min() # weights - x = (x[:n] * w.reshape(n, 1)).sum(0) / w.sum() # new parent + parent = 'weighted' # parent selection method: 'single' or 'weighted' + if parent == 'single' or len(x) == 1: + x = x[fitness(x).argmax()] + elif parent == 'weighted': # weighted combination + n = min(10, x.shape[0]) # number to merge + x = x[np.argsort(-fitness(x))][:n] # top n mutations + w = fitness(x) - fitness(x).min() # weights + x = (x[:n] * w.reshape(n, 1)).sum(0) / w.sum() # new parent for i, k in enumerate(hyp.keys()): hyp[k] = x[i + 7]