This commit is contained in:
Glenn Jocher 2019-07-30 15:58:10 +02:00
parent 8a74a683ae
commit 65abb1c82f
1 changed files with 3 additions and 3 deletions

View File

@ -157,7 +157,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False): def __init__(self, path, img_size=416, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False):
path = str(Path(path)) # os-agnostic path = str(Path(path)) # os-agnostic
with open(path, 'r') as f: with open(path, 'r') as f:
self.img_files = [x for x in f.read().splitlines() if os.path.splitext(x)[-1].lower() in img_formats] self.img_files = [x.replace('/', os.sep) for x in f.read().splitlines() # os-agnostic
if os.path.splitext(x)[-1].lower() in img_formats]
n = len(self.img_files) n = len(self.img_files)
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
@ -214,9 +215,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
preload_labels = False preload_labels = False
if preload_labels: if preload_labels:
self.labels = [np.zeros((0, 5))] * n self.labels = [np.zeros((0, 5))] * n
iter = tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files
extract_bounding_boxes = False extract_bounding_boxes = False
for i, file in enumerate(iter): for i, file in enumerate(tqdm(self.label_files, desc='Reading labels') if n > 10 else self.label_files):
try: try:
with open(file, 'r') as f: with open(file, 'r') as f:
l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)