This commit is contained in:
Glenn Jocher 2019-04-27 21:38:20 +02:00
parent ccfd44c2f8
commit 7652365b28
2 changed files with 8 additions and 9 deletions

View File

@ -180,7 +180,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Preload labels (required for weighted CE training)
self.labels = [np.array([])] * n
iter = tqdm(self.label_files, desc='Reading labels') if n > 5000 else self.label_files
iter = tqdm(self.label_files, desc='Reading labels') if n > 1000 else self.label_files
for i, file in enumerate(iter):
try:
with open(file, 'r') as f:

View File

@ -52,12 +52,11 @@ def model_info(model):
def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int)
n = np.bincount(classes, minlength=nc)
weights = np.zeros(nc)
i = n.nonzero()
weights[i] = 1 / n[i] # number of targets per class
weights /= weights.sum()
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
weights = np.bincount(classes, minlength=nc) # occurences per class
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize
return torch.Tensor(weights)
@ -527,7 +526,7 @@ def plot_images(imgs, targets, fname='images.jpg'):
plt.close()
def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
def plot_results(start=1, stop=0): # from utils.utils import *; plot_results()
# Plot training results files 'results*.txt'
# import os; os.system('wget https://storage.googleapis.com/ultralytics/yolov3/results_v3.txt')
@ -542,6 +541,6 @@ def plot_results(start=0, stop=0): # from utils.utils import *; plot_results()
for i in range(10):
ax[i].plot(x, results[i, x], marker='.', label=f.replace('.txt', ''))
ax[i].set_title(s[i])
ax[0].legend()
fig.tight_layout()
ax[4].legend()
fig.savefig('results.png', dpi=300)