add generations arg to kmeans()
This commit is contained in:
		
							parent
							
								
									f65a50d13d
								
							
						
					
					
						commit
						efc754a794
					
				|  | @ -679,11 +679,12 @@ def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43): | ||||||
|             shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg'))  # copy images |             shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg'))  # copy images | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=0.10): | def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=0.10, gen=1000): | ||||||
|     # Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors() |     # Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors() | ||||||
|     # n: number of anchors |     # n: number of anchors | ||||||
|     # img_size: (min, max) image size used for multi-scale training (can be same values) |     # img_size: (min, max) image size used for multi-scale training (can be same values) | ||||||
|     # thr: IoU threshold hyperparameter used for training (0.0 - 1.0) |     # thr: IoU threshold hyperparameter used for training (0.0 - 1.0) | ||||||
|  |     # gen: generations to evolve anchors using genetic algorithm | ||||||
|     from utils.datasets import LoadImagesAndLabels |     from utils.datasets import LoadImagesAndLabels | ||||||
| 
 | 
 | ||||||
|     def print_results(k): |     def print_results(k): | ||||||
|  | @ -742,8 +743,8 @@ def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr= | ||||||
| 
 | 
 | ||||||
|     # Evolve |     # Evolve | ||||||
|     npr = np.random |     npr = np.random | ||||||
|     f, sh, ng, mp, s = fitness(k), k.shape, 1000, 0.9, 0.1  # fitness, generations, mutation prob, sigma |     f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1  # fitness, generations, mutation prob, sigma | ||||||
|     for _ in tqdm(range(ng), desc='Evolving anchors'): |     for _ in tqdm(range(gen), desc='Evolving anchors'): | ||||||
|         v = np.ones(sh) |         v = np.ones(sh) | ||||||
|         while (v == 1).all():  # mutate until a change occurs (prevent duplicates) |         while (v == 1).all():  # mutate until a change occurs (prevent duplicates) | ||||||
|             v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)  # 98.6, 61.6 |             v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)  # 98.6, 61.6 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue