This commit is contained in:
Glenn Jocher 2019-05-29 18:04:11 +02:00
parent 9cf5ab0c9d
commit 0847334241
1 changed files with 45 additions and 0 deletions

View File

@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from . import torch_utils from . import torch_utils
@ -490,6 +492,49 @@ def coco_only_people(path='../coco/labels/val2014/'):
print(labels.shape[0], file) print(labels.shape[0], file)
def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets()
with open(path, 'r') as f:
img_files = f.read().splitlines()
img_files = list(filter(lambda x: len(x) > 0, img_files))
# Read shapes
n = len(img_files)
assert n > 0, 'No images found in %s' % path
label_files = [x.replace('images', 'labels').
replace('.jpeg', '.txt').
replace('.jpg', '.txt').
replace('.bmp', '.txt').
replace('.png', '.txt') for x in img_files]
s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height)
# Read targets
labels = [np.zeros((0, 5))] * n
iter = tqdm(label_files, desc='Reading labels')
for i, file in enumerate(iter):
try:
with open(file, 'r') as f:
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
if l.shape[0]:
assert l.shape[1] == 5, '> 5 label columns: %s' % file
assert (l >= 0).all(), 'negative labels: %s' % file
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
l[:, [1, 3]] *= s[i][0]
l[:, [2, 4]] *= s[i][1]
l[:, 1:] *= 320 / max(s[i])
labels[i] = l
except:
pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file
assert len(np.concatenate(labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'
# kmeans
from scipy import cluster
wh = np.concatenate(labels, 0)[:, 3:5]
k = cluster.vq.kmeans(wh, 9)[0]
k = k[np.argsort(k.prod(1))]
for x in k.ravel():
print('%.1f, ' % x, end='')
# Plotting functions --------------------------------------------------------------------------------------------------- # Plotting functions ---------------------------------------------------------------------------------------------------
def plot_one_box(x, img, color=None, label=None, line_thickness=None): def plot_one_box(x, img, color=None, label=None, line_thickness=None):