diff --git a/utils/datasets.py b/utils/datasets.py index 2103e45e..a0b17421 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -247,6 +247,19 @@ class LoadImagesAndLabels(Dataset): # for training/testing pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file assert len(np.concatenate(self.labels, 0)) > 0, 'No labels found. Incorrect label paths provided.' + # Cache images into memory for faster training (~5GB) + cache_images = False + if cache_images and augment: # if training + for i in tqdm(range(min(len(self.img_files), 10000)), desc='Caching images'): # max 10k images + img_path = self.img_files[i] + img = cv2.imread(img_path) # BGR + assert img is not None, 'Image Not Found ' + img_path + r = self.img_size / max(img.shape) # size ratio + if self.augment and r < 1: # if training (NOT testing), downsize to inference shape + h, w, _ = img.shape + img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) # or INTER_AREA + self.imgs[i] = img + # Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3 detect_corrupted_images = False if detect_corrupted_images: @@ -284,9 +297,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing h, w, _ = img.shape img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) # INTER_LINEAR fastest - if index < 5000: # cache first 5000 images into memory (~5GB) - self.imgs[index] = img - # Augment colorspace augment_hsv = True if self.augment and augment_hsv: @@ -320,7 +330,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing if x is None: # labels not preloaded with open(label_path, 'r') as f: x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) - self.labels[index] = x # save for next time if x.size > 0: # Normalized xywh to pixel xyxy format