This commit is contained in:
Glenn Jocher 2019-12-30 13:39:25 -08:00
parent 9e58191983
commit cf92235b8d
1 changed files with 4 additions and 5 deletions

View File

@ -743,13 +743,12 @@ def kmean_anchors(path='data/coco64.txt', n=9, img_size=(288, 640)): # from uti
wh = [] wh = []
dataset = LoadImagesAndLabels(path, augment=True, rect=True, cache_labels=True) dataset = LoadImagesAndLabels(path, augment=True, rect=True, cache_labels=True)
for s, l in zip(dataset.shapes, dataset.labels): for s, l in zip(dataset.shapes, dataset.labels):
l = l[:, 3:5] * (s / max(s)) # image normalized to letterbox normalized wh wh.append(l[:, 3:5] * (s / s.max())) # image normalized to letterbox normalized wh
l = l.repeat(10, axis=0) # augment 10x wh = np.concatenate(wh, 0).repeat(10, axis=0) # augment 10x
l *= np.random.uniform(img_size[0], img_size[1], size=(l.shape[0], 1)) # normalized to pixels (multi-scale) wh *= np.random.uniform(img_size[0], img_size[1], size=(wh.shape[0], 1)) # normalized to pixels (multi-scale)
wh.append(l)
wh = np.concatenate(wh, 0) # wh from cxywh
# Kmeans calculation # Kmeans calculation
print('Running kmeans...')
k, dist = cluster.vq.kmeans(wh, n) # points, mean distance k, dist = cluster.vq.kmeans(wh, n) # points, mean distance
k = k[np.argsort(k.prod(1))] # sort small to large k = k[np.argsort(k.prod(1))] # sort small to large