This commit is contained in:
Glenn Jocher 2020-01-29 14:18:45 -08:00
parent 9e97c4cadb
commit 8fac566a87
1 changed files with 3 additions and 2 deletions

View File

@ -799,12 +799,13 @@ def kmean_anchors(path='../coco/train2017.txt', n=9, img_size=(320, 640)):
# ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
# Evolve # Evolve
npr = np.random
wh = torch.Tensor(wh) wh = torch.Tensor(wh)
f, sh, ng, mp, s = fitness(thr, wh, k), k.shape, 1000, 0.1, 0.3 # fitness, generations, mutation probability, sigma f, sh, ng, mp, s = fitness(thr, wh, k), k.shape, 1000, 0.9, 0.1 # fitness, generations, mutation probability, sigma
for _ in tqdm(range(ng), desc='Evolving anchors'): for _ in tqdm(range(ng), desc='Evolving anchors'):
v = np.ones(sh) v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates) while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
v = ((np.random.random(sh) < mp) * np.random.randn(*sh) * s + 1) ** 2.0 v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) # 98.6, 61.6
kg = (k.copy() * v).clip(min=2.0) kg = (k.copy() * v).clip(min=2.0)
fg = fitness(thr, wh, kg) fg = fitness(thr, wh, kg)
if fg > f: if fg > f: