label *.npy saving for faster caching

This commit is contained in:
Glenn Jocher 2020-05-20 21:13:41 -07:00
parent cd5f6227d9
commit 3ddaf3b63c
1 changed files with 23 additions and 10 deletions

View File

@ -317,12 +317,22 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Cache labels
self.imgs = [None] * n
self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
extract_bounding_boxes = False
create_datasubset = False
pbar = tqdm(self.label_files, desc='Caching labels')
create_datasubset, extract_bounding_boxes = False, False
nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate
np_labels_path = str(Path(self.label_files[0]).parent) + '.npy' # saved labels in *.npy file
if os.path.isfile(np_labels_path):
print('Loading labels from %s' % np_labels_path)
self.labels = list(np.load(np_labels_path, allow_pickle=True))
labels_loaded = True
else:
self.labels = [np.zeros((0, 5), dtype=np.float32)] * n
labels_loaded = False
pbar = tqdm(self.label_files, desc='Caching labels')
for i, file in enumerate(pbar):
if labels_loaded:
l = self.labels[i]
else:
try:
with open(file, 'r') as f:
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
@ -378,6 +388,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
pbar.desc = 'Caching labels (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
nf, nm, ne, nd, n)
assert nf > 0, 'No labels found in %s. See %s' % (os.path.dirname(file) + os.sep, help_url)
if not labels_loaded:
print('Saving labels to %s for faster future loading' % np_labels_path)
np.save(np_labels_path, self.labels) # save for next time
# Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
if cache_images: # if training