updates
This commit is contained in:
parent
2089e4f4c8
commit
286257c5ac
|
@ -131,15 +131,20 @@ class LoadWebcam: # for inference
|
|||
class LoadImagesAndLabels(Dataset): # for training/testing
|
||||
def __init__(self, path, img_size=416, augment=False):
|
||||
with open(path, 'r') as file:
|
||||
self.img_files = file.read().splitlines()
|
||||
self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))
|
||||
assert len(self.img_files) > 0, 'No images found in %s' % path
|
||||
img_files = file.read().splitlines()
|
||||
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
|
||||
|
||||
n = len(self.img_files)
|
||||
assert n > 0, 'No images found in %s' % path
|
||||
self.img_size = img_size
|
||||
self.augment = augment
|
||||
self.label_files = [
|
||||
x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
|
||||
for x in self.img_files]
|
||||
|
||||
if n < 200: # preload all images into memory if possible
|
||||
self.imgs = (cv2.imread(img_files[i]) for i in range(n))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_files)
|
||||
|
||||
|
@ -147,6 +152,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
img_path = self.img_files[index]
|
||||
label_path = self.label_files[index]
|
||||
|
||||
if hasattr(self, 'imgs'):
|
||||
img = self.imgs[index] # BGR
|
||||
else:
|
||||
img = cv2.imread(img_path) # BGR
|
||||
assert img is not None, 'File Not Found ' + img_path
|
||||
|
||||
|
|
Loading…
Reference in New Issue