This commit is contained in:
Glenn Jocher 2019-05-29 18:04:11 +02:00
parent 9cf5ab0c9d
commit 0847334241
1 changed files with 45 additions and 0 deletions

View File

@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from . import torch_utils
@ -490,6 +492,49 @@ def coco_only_people(path='../coco/labels/val2014/'):
print(labels.shape[0], file)
def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets()
with open(path, 'r') as f:
img_files = f.read().splitlines()
img_files = list(filter(lambda x: len(x) > 0, img_files))
# Read shapes
n = len(img_files)
assert n > 0, 'No images found in %s' % path
label_files = [x.replace('images', 'labels').
replace('.jpeg', '.txt').
replace('.jpg', '.txt').
replace('.bmp', '.txt').
replace('.png', '.txt') for x in img_files]
s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height)
# Read targets
labels = [np.zeros((0, 5))] * n
iter = tqdm(label_files, desc='Reading labels')
for i, file in enumerate(iter):
try:
with open(file, 'r') as f:
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
if l.shape[0]:
assert l.shape[1] == 5, '> 5 label columns: %s' % file
assert (l >= 0).all(), 'negative labels: %s' % file
assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
l[:, [1, 3]] *= s[i][0]
l[:, [2, 4]] *= s[i][1]
l[:, 1:] *= 320 / max(s[i])
labels[i] = l
except:
pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file
assert len(np.concatenate(labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'
# kmeans
from scipy import cluster
wh = np.concatenate(labels, 0)[:, 3:5]
k = cluster.vq.kmeans(wh, 9)[0]
k = k[np.argsort(k.prod(1))]
for x in k.ravel():
print('%.1f, ' % x, end='')
# Plotting functions ---------------------------------------------------------------------------------------------------
def plot_one_box(x, img, color=None, label=None, line_thickness=None):