add generations arg to kmeans()
This commit is contained in:
parent
f65a50d13d
commit
efc754a794
|
@ -679,11 +679,12 @@ def coco_single_class_labels(path='../coco/labels/train2014/', label_class=43):
|
||||||
shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
|
shutil.copyfile(src=img_file, dst='new/images/' + Path(file).name.replace('txt', 'jpg')) # copy images
|
||||||
|
|
||||||
|
|
||||||
def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=0.10):
|
def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=0.10, gen=1000):
|
||||||
# Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
|
# Creates kmeans anchors for use in *.cfg files: from utils.utils import *; _ = kmean_anchors()
|
||||||
# n: number of anchors
|
# n: number of anchors
|
||||||
# img_size: (min, max) image size used for multi-scale training (can be same values)
|
# img_size: (min, max) image size used for multi-scale training (can be same values)
|
||||||
# thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
|
# thr: IoU threshold hyperparameter used for training (0.0 - 1.0)
|
||||||
|
# gen: generations to evolve anchors using genetic algorithm
|
||||||
from utils.datasets import LoadImagesAndLabels
|
from utils.datasets import LoadImagesAndLabels
|
||||||
|
|
||||||
def print_results(k):
|
def print_results(k):
|
||||||
|
@ -742,8 +743,8 @@ def kmean_anchors(path='../coco/train2017.txt', n=12, img_size=(320, 1024), thr=
|
||||||
|
|
||||||
# Evolve
|
# Evolve
|
||||||
npr = np.random
|
npr = np.random
|
||||||
f, sh, ng, mp, s = fitness(k), k.shape, 1000, 0.9, 0.1 # fitness, generations, mutation prob, sigma
|
f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
|
||||||
for _ in tqdm(range(ng), desc='Evolving anchors'):
|
for _ in tqdm(range(gen), desc='Evolving anchors'):
|
||||||
v = np.ones(sh)
|
v = np.ones(sh)
|
||||||
while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
|
while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
|
||||||
v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) # 98.6, 61.6
|
v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) # 98.6, 61.6
|
||||||
|
|
Loading…
Reference in New Issue