diff --git a/utils/utils.py b/utils/utils.py index 4f787fc5..08b93792 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -7,6 +7,8 @@ import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn +from PIL import Image +from tqdm import tqdm from . import torch_utils @@ -490,6 +492,49 @@ def coco_only_people(path='../coco/labels/val2014/'): print(labels.shape[0], file) +def kmeans_targets(path='./data/coco_64img.txt'): # from utils.utils import *; kmeans_targets() + with open(path, 'r') as f: + img_files = f.read().splitlines() + img_files = list(filter(lambda x: len(x) > 0, img_files)) + + # Read shapes + n = len(img_files) + assert n > 0, 'No images found in %s' % path + label_files = [x.replace('images', 'labels'). + replace('.jpeg', '.txt'). + replace('.jpg', '.txt'). + replace('.bmp', '.txt'). + replace('.png', '.txt') for x in img_files] + s = np.array([Image.open(f).size for f in tqdm(img_files, desc='Reading image shapes')]) # (width, height) + + # Read targets + labels = [np.zeros((0, 5))] * n + iter = tqdm(label_files, desc='Reading labels') + for i, file in enumerate(iter): + try: + with open(file, 'r') as f: + l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) + if l.shape[0]: + assert l.shape[1] == 5, '> 5 label columns: %s' % file + assert (l >= 0).all(), 'negative labels: %s' % file + assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file + l[:, [1, 3]] *= s[i][0] + l[:, [2, 4]] *= s[i][1] + l[:, 1:] *= 320 / max(s[i]) + labels[i] = l + except: + pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file + assert len(np.concatenate(labels, 0)) > 0, 'No labels found. Incorrect label paths provided.' + + # kmeans + from scipy import cluster + wh = np.concatenate(labels, 0)[:, 3:5] + k = cluster.vq.kmeans(wh, 9)[0] + k = k[np.argsort(k.prod(1))] + for x in k.ravel(): + print('%.1f, ' % x, end='') + + # Plotting functions --------------------------------------------------------------------------------------------------- def plot_one_box(x, img, color=None, label=None, line_thickness=None):