This commit is contained in:
Glenn Jocher 2019-07-25 13:23:39 +02:00
parent 7b6cba86ef
commit a834377122
1 changed file with 8 additions and 7 deletions

View File

@@ -543,19 +543,20 @@ def select_best_evolve(path='evolve*.txt'): # from utils.utils import *; select
print(file, x[fitness.argmax()]) print(file, x[fitness.argmax()])
def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets() def kmeans_targets(path='./data/coco_64img.txt', n=9, img_size=320): # from utils.utils import *; kmeans_targets()
# Produces a list of target kmeans suitable for use in *.cfg files
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif'] img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif']
with open(path, 'r') as f: with open(path, 'r') as f:
img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats] img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats]
# Read shapes # Read shapes
n = len(img_files) nf = len(img_files)
assert n > 0, 'No images found in %s' % path assert nf > 0, 'No images found in %s' % path
label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in img_files] label_files = [x.replace('images', 'labels').replace(os.path.splitext(x)[-1], '.txt') for x in img_files]
s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height) s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height)
# Read targets # Read targets
labels = [np.zeros((0, 5))] * n labels = [np.zeros((0, 5))] * nf
iter = tqdm(label_files, desc='Reading labels') iter = tqdm(label_files, desc='Reading labels')
for i, file in enumerate(iter): for i, file in enumerate(iter):
try: try:
@@ -567,16 +568,16 @@ def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *;
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
l[:, [1, 3]] *= s[i][0] l[:, [1, 3]] *= s[i][0]
l[:, [2, 4]] *= s[i][1] l[:, [2, 4]] *= s[i][1]
l[:, 1:] *= 320 / max(s[i]) l[:, 1:] *= img_size / max(s[i]) # nominal img_size for training here
labels[i] = l labels[i] = l
except: except:
pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file
assert len(np.concatenate(labels, 0)) > 0, 'No labels found. Incorrect label paths provided.' assert len(np.concatenate(labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'
# kmeans # kmeans calculation
from scipy import cluster from scipy import cluster
wh = np.concatenate(labels, 0)[:, 3:5] wh = np.concatenate(labels, 0)[:, 3:5]
k = cluster.vq.kmeans(wh, 9)[0] k = cluster.vq.kmeans(wh, n)[0]
k = k[np.argsort(k.prod(1))] k = k[np.argsort(k.prod(1))]
for x in k.ravel(): for x in k.ravel():
print('%.1f, ' % x, end='') print('%.1f, ' % x, end='')