This commit is contained in:
Glenn Jocher 2019-03-21 02:28:30 +02:00
parent fc87e9af1f
commit 2cd6805063
1 changed files with 7 additions and 1 deletions

View File

@ -113,7 +113,12 @@ class LoadImagesAndLabels: # for training
return self return self
def __getitem__(self, index): def __getitem__(self, index):
return self.load_images(index, index + 1) imgs, labels0, img_paths, img_shapes = self.load_images(index, index + 1)
labels0[:,0] = index % self.batch_size
labels = torch.zeros(100, 6)
labels[:min(len(labels0), 100)] = labels0 # max 100 labels per image
return imgs.squeeze(0), labels, img_paths, img_shapes
def __next__(self): def __next__(self):
self.count += 1 # batches self.count += 1 # batches
@ -122,6 +127,7 @@ class LoadImagesAndLabels: # for training
ia = self.count * self.batch_size # start index ia = self.count * self.batch_size # start index
ib = min(ia + self.batch_size, self.nF) # end index ib = min(ia + self.batch_size, self.nF) # end index
return self.load_images(ia, ib) return self.load_images(ia, ib)
def load_images(self, ia, ib): def load_images(self, ia, ib):