remove label loading during training

This commit is contained in:
Glenn Jocher 2020-04-07 16:57:22 -07:00
parent b3e1d74478
commit 6c5ecaf805
1 changed files with 9 additions and 17 deletions

View File

@ -414,9 +414,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
if self.image_weights: if self.image_weights:
index = self.indices[index] index = self.indices[index]
img_path = self.img_files[index]
label_path = self.label_files[index]
hyp = self.hyp hyp = self.hyp
if self.mosaic: if self.mosaic:
# Load mosaic # Load mosaic
@ -434,19 +431,14 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Load labels # Load labels
labels = [] labels = []
if os.path.isfile(label_path): x = self.labels[index]
x = self.labels[index] if x is not None and x.size > 0:
if x is None: # labels not preloaded # Normalized xywh to pixel xyxy format
with open(label_path, 'r') as f: labels = x.copy()
x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
if x.size > 0: labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
# Normalized xywh to pixel xyxy format labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
labels = x.copy()
labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
if self.augment: if self.augment:
# Augment imagespace # Augment imagespace
@ -496,7 +488,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
return torch.from_numpy(img), labels_out, img_path, shapes return torch.from_numpy(img), labels_out, self.img_files[index], shapes
@staticmethod @staticmethod
def collate_fn(batch): def collate_fn(batch):