add generations arg to kmeans()

This commit is contained in:
Glenn Jocher 2020-04-12 12:49:23 -07:00
parent f65a50d13d
commit efc754a794
1 changed files with 4 additions and 3 deletions

View File

@ -679,11 +679,12 @@ def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=0.10): def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=0.10, gen=1000):
# Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors() # Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
# n: number of anchors # n: number of anchors
# img_size: (min, max) image size used for multi-scale training (can be same values) # img_size: (min, max) image size used for multi-scale training (can be same values)
# thr: IoU threshold hyperparameter used for training (0.0 - 1.0) # thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
# gen: generations to evolve anchors using genetic algorithm
from utils.datasets import LoadImagesAndLabels from utils.datasets import LoadImagesAndLabels
def print_results(k): def print_results(k):
@ -742,8 +743,8 @@ def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=
# Evolve # Evolve
npr = np.random npr = np.random
f, sh, ng, mp, s = fitness(k), k.shape, 1000, 0.9, 0.1 # fitness, generations, mutation prob, sigma f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
for _ in tqdm(range(ng), desc='Evolving anchors'): for _ in tqdm(range(gen), desc='Evolving anchors'):
v = np.ones(sh) v = np.ones(sh)
while (v == 1).all(): # mutate until a change occurs (prevent duplicates) while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) # 98.6, 61.6 v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) # 98.6, 61.6