updates
This commit is contained in:
parent
327aaebd7c
commit
ca67e2353b
|
@ -92,34 +92,39 @@ class LoadWebcam: # for inference
|
||||||
class LoadImagesAndLabels: # for training
|
class LoadImagesAndLabels: # for training
|
||||||
def __init__(self, path, batch_size=1, img_size=608, augment=False):
|
def __init__(self, path, batch_size=1, img_size=608, augment=False):
|
||||||
with open(path, 'r') as file:
|
with open(path, 'r') as file:
|
||||||
self.img_files = file.readlines()
|
self.img_files = file.read().splitlines()
|
||||||
self.img_files = [x.replace('\n', '') for x in self.img_files]
|
|
||||||
self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))
|
self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))
|
||||||
|
|
||||||
|
self.nF = len(self.img_files) # number of image files
|
||||||
|
self.nB = math.ceil(self.nF / batch_size) # number of batches
|
||||||
|
assert self.nF > 0, 'No images found in %s' % path
|
||||||
|
|
||||||
self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
|
self.label_files = [x.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
|
||||||
for x in self.img_files]
|
for x in self.img_files]
|
||||||
|
|
||||||
self.nF = len(self.img_files) # number of image files
|
|
||||||
self.nB = math.ceil(self.nF / batch_size) # number of batches
|
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.img_size = img_size
|
self.img_size = img_size
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
|
iter(self)
|
||||||
assert self.nF > 0, 'No images found in %s' % path
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
self.count = -1
|
self.count = -1
|
||||||
self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
|
self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.load_images(index, index + 1)
|
||||||
|
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
self.count += 1
|
self.count += 1 # batches
|
||||||
if self.count == self.nB:
|
if self.count >= self.nB:
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
||||||
ia = self.count * self.batch_size
|
ia = self.count * self.batch_size # start index
|
||||||
ib = min((self.count + 1) * self.batch_size, self.nF)
|
ib = min(ia + self.batch_size, self.nF) # end index
|
||||||
|
return self.load_images(ia, ib)
|
||||||
|
|
||||||
|
def load_images(self, ia, ib):
|
||||||
img_all, labels_all, img_paths, img_shapes = [], [], [], []
|
img_all, labels_all, img_paths, img_shapes = [], [], [], []
|
||||||
for index, files_index in enumerate(range(ia, ib)):
|
for index, files_index in enumerate(range(ia, ib)):
|
||||||
img_path = self.img_files[self.shuffled_vector[files_index]]
|
img_path = self.img_files[self.shuffled_vector[files_index]]
|
||||||
|
|
Loading…
Reference in New Issue