This commit is contained in:
Glenn Jocher 2019-03-21 01:57:16 +02:00
parent 327aaebd7c
commit ca67e2353b
1 changed files with 15 additions and 10 deletions

View File

@ -92,34 +92,39 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels: # for training
def __init__(self, path, batch_size=1, img_size=608, augment=False):
with open(path, 'r') as file:
self.img_files = file.readlines()
self.img_files = [x.replace('\n', '') for x in self.img_files]
self.img_files = file.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))
self.nF = len(self.img_files) # number of image files
self.nB = math.ceil(self.nF / batch_size) # number of batches
assert self.nF > 0, 'No images found in %s' % path
self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
for x in self.img_files]
self.nF = len(self.img_files) # number of image files
self.nB = math.ceil(self.nF / batch_size) # number of batches
self.batch_size = batch_size
self.img_size = img_size
self.augment = augment
assert self.nF > 0, 'No images found in %s' % path
iter(self)
def __iter__(self):
self.count = -1
self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
return self
def __getitem__(self, index):
return self.load_images(index, index + 1)
def __next__(self):
self.count += 1
if self.count == self.nB:
self.count += 1 # batches
if self.count >= self.nB:
raise StopIteration
ia = self.count * self.batch_size
ib = min((self.count + 1) * self.batch_size, self.nF)
ia = self.count * self.batch_size # start index
ib = min(ia + self.batch_size, self.nF) # end index
return self.load_images(ia, ib)
def load_images(self, ia, ib):
img_all, labels_all, img_paths, img_shapes = [], [], [], []
for index, files_index in enumerate(range(ia, ib)):
img_path = self.img_files[self.shuffled_vector[files_index]]