updates
This commit is contained in:
		
							parent
							
								
									9e97c4cadb
								
							
						
					
					
						commit
						8fac566a87
					
				|  | @ -799,12 +799,13 @@ def kmean_anchors(path='../coco/train2017.txt', n=9, img_size=(320, 640)): | ||||||
|     # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') |     # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') | ||||||
| 
 | 
 | ||||||
|     # Evolve |     # Evolve | ||||||
|  |     npr = np.random | ||||||
|     wh = torch.Tensor(wh) |     wh = torch.Tensor(wh) | ||||||
|     f, sh, ng, mp, s = fitness(thr, wh, k), k.shape, 1000, 0.1, 0.3  # fitness, generations, mutation probability, sigma |     f, sh, ng, mp, s = fitness(thr, wh, k), k.shape, 1000, 0.9, 0.1  # fitness, generations, mutation probability, sigma | ||||||
|     for _ in tqdm(range(ng), desc='Evolving anchors'): |     for _ in tqdm(range(ng), desc='Evolving anchors'): | ||||||
|         v = np.ones(sh) |         v = np.ones(sh) | ||||||
|         while (v == 1).all():  # mutate until a change occurs (prevent duplicates) |         while (v == 1).all():  # mutate until a change occurs (prevent duplicates) | ||||||
|             v = ((np.random.random(sh) < mp) * np.random.randn(*sh) * s + 1) ** 2.0 |             v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)  # 98.6, 61.6 | ||||||
|         kg = (k.copy() * v).clip(min=2.0) |         kg = (k.copy() * v).clip(min=2.0) | ||||||
|         fg = fitness(thr, wh, kg) |         fg = fitness(thr, wh, kg) | ||||||
|         if fg > f: |         if fg > f: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue