diff --git a/utils/datasets.py b/utils/datasets.py index e2b03c9a..6da5eff8 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -317,16 +317,16 @@ class LoadImagesAndLabels(Dataset): # for training/testing # Cache labels self.imgs = [None] * n - create_datasubset, extract_bounding_boxes = False, False + self.labels = [np.zeros((0, 5), dtype=np.float32)] * n + create_datasubset, extract_bounding_boxes, labels_loaded = False, False, False nm, nf, ne, ns, nd = 0, 0, 0, 0, 0 # number missing, found, empty, datasubset, duplicate np_labels_path = str(Path(self.label_files[0]).parent) + '.npy' # saved labels in *.npy file if os.path.isfile(np_labels_path): print('Loading labels from %s' % np_labels_path) - self.labels = list(np.load(np_labels_path, allow_pickle=True)) - labels_loaded = True - else: - self.labels = [np.zeros((0, 5), dtype=np.float32)] * n - labels_loaded = False + x = list(np.load(np_labels_path, allow_pickle=True)) + if len(x) == n: + self.labels = x + labels_loaded = True pbar = tqdm(self.label_files, desc='Caching labels') for i, file in enumerate(pbar):