remove label loading during training
This commit is contained in:
parent
b3e1d74478
commit
6c5ecaf805
|
@ -414,9 +414,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
if self.image_weights:
|
||||
index = self.indices[index]
|
||||
|
||||
img_path = self.img_files[index]
|
||||
label_path = self.label_files[index]
|
||||
|
||||
hyp = self.hyp
|
||||
if self.mosaic:
|
||||
# Load mosaic
|
||||
|
@ -434,19 +431,14 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
|
||||
# Load labels
|
||||
labels = []
|
||||
if os.path.isfile(label_path):
|
||||
x = self.labels[index]
|
||||
if x is None: # labels not preloaded
|
||||
with open(label_path, 'r') as f:
|
||||
x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
|
||||
|
||||
if x.size > 0:
|
||||
# Normalized xywh to pixel xyxy format
|
||||
labels = x.copy()
|
||||
labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
|
||||
labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
|
||||
labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
|
||||
labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
|
||||
x = self.labels[index]
|
||||
if x is not None and x.size > 0:
|
||||
# Normalized xywh to pixel xyxy format
|
||||
labels = x.copy()
|
||||
labels[:, 1] = ratio[0] * w * (x[:, 1] - x[:, 3] / 2) + pad[0] # pad width
|
||||
labels[:, 2] = ratio[1] * h * (x[:, 2] - x[:, 4] / 2) + pad[1] # pad height
|
||||
labels[:, 3] = ratio[0] * w * (x[:, 1] + x[:, 3] / 2) + pad[0]
|
||||
labels[:, 4] = ratio[1] * h * (x[:, 2] + x[:, 4] / 2) + pad[1]
|
||||
|
||||
if self.augment:
|
||||
# Augment imagespace
|
||||
|
@ -496,7 +488,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
|
|||
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
return torch.from_numpy(img), labels_out, img_path, shapes
|
||||
return torch.from_numpy(img), labels_out, self.img_files[index], shapes
|
||||
|
||||
@staticmethod
|
||||
def collate_fn(batch):
|
||||
|
|
Loading…
Reference in New Issue