This commit is contained in:
Glenn Jocher 2019-04-18 14:47:05 +02:00
parent 2089e4f4c8
commit 286257c5ac
1 changed files with 12 additions and 4 deletions

View File

@ -131,15 +131,20 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, augment=False): def __init__(self, path, img_size=416, augment=False):
with open(path, 'r') as file: with open(path, 'r') as file:
self.img_files = file.read().splitlines() img_files = file.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, self.img_files)) self.img_files = list(filter(lambda x: len(x) > 0, img_files))
assert len(self.img_files) > 0, 'No images found in %s' % path
n = len(self.img_files)
assert n > 0, 'No images found in %s' % path
self.img_size = img_size self.img_size = img_size
self.augment = augment self.augment = augment
self.label_files = [ self.label_files = [
x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt') x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
for x in self.img_files] for x in self.img_files]
if n < 200: # preload all images into memory if possible
self.imgs = (cv2.imread(img_files[i]) for i in range(n))
def __len__(self): def __len__(self):
return len(self.img_files) return len(self.img_files)
@ -147,6 +152,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img_path = self.img_files[index] img_path = self.img_files[index]
label_path = self.label_files[index] label_path = self.label_files[index]
if hasattr(self, 'imgs'):
img = self.imgs[index] # BGR
else:
img = cv2.imread(img_path) # BGR img = cv2.imread(img_path) # BGR
assert img is not None, 'File Not Found ' + img_path assert img is not None, 'File Not Found ' + img_path