diff --git a/utils/datasets.py b/utils/datasets.py index 0cd86dd2..6ca270d1 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -92,34 +92,39 @@ class LoadWebcam: # for inference class LoadImagesAndLabels: # for training def __init__(self, path, batch_size=1, img_size=608, augment=False): with open(path, 'r') as file: - self.img_files = file.readlines() - self.img_files = [x.replace('\n', '') for x in self.img_files] + self.img_files = file.read().splitlines() self.img_files = list(filter(lambda x: len(x) > 0, self.img_files)) + self.nF = len(self.img_files) # number of image files + self.nB = math.ceil(self.nF / batch_size) # number of batches + assert self.nF > 0, 'No images found in %s' % path + self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for x in self.img_files] - self.nF = len(self.img_files) # number of image files - self.nB = math.ceil(self.nF / batch_size) # number of batches self.batch_size = batch_size self.img_size = img_size self.augment = augment - - assert self.nF > 0, 'No images found in %s' % path + iter(self) def __iter__(self): self.count = -1 self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF) return self + def __getitem__(self, index): + return self.load_images(index, index + 1) + def __next__(self): - self.count += 1 - if self.count == self.nB: + self.count += 1 # batches + if self.count >= self.nB: raise StopIteration - ia = self.count * self.batch_size - ib = min((self.count + 1) * self.batch_size, self.nF) + ia = self.count * self.batch_size # start index + ib = min(ia + self.batch_size, self.nF) # end index + return self.load_images(ia, ib) + def load_images(self, ia, ib): img_all, labels_all, img_paths, img_shapes = [], [], [], [] for index, files_index in enumerate(range(ia, ib)): img_path = self.img_files[self.shuffled_vector[files_index]]