diff --git a/utils/datasets.py b/utils/datasets.py index 35f7c702..ff1a0a11 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -131,15 +131,20 @@ class LoadWebcam: # for inference class LoadImagesAndLabels(Dataset): # for training/testing def __init__(self, path, img_size=416, augment=False): with open(path, 'r') as file: - self.img_files = file.read().splitlines() - self.img_files = list(filter(lambda x: len(x) > 0, self.img_files)) - assert len(self.img_files) > 0, 'No images found in %s' % path + img_files = file.read().splitlines() + self.img_files = list(filter(lambda x: len(x) > 0, img_files)) + + n = len(self.img_files) + assert n > 0, 'No images found in %s' % path self.img_size = img_size self.augment = augment self.label_files = [ x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt') for x in self.img_files] + if n < 200: # preload all images into memory if possible + self.imgs = (cv2.imread(img_files[i]) for i in range(n)) + def __len__(self): return len(self.img_files) @@ -147,7 +152,10 @@ class LoadImagesAndLabels(Dataset): # for training/testing img_path = self.img_files[index] label_path = self.label_files[index] - img = cv2.imread(img_path) # BGR + if hasattr(self, 'imgs'): + img = self.imgs[index] # BGR + else: + img = cv2.imread(img_path) # BGR assert img is not None, 'File Not Found ' + img_path augment_hsv = True