This commit is contained in:
Glenn Jocher 2019-08-06 17:24:30 +02:00
parent 141032045b
commit 01bc76faeb
1 changed file with 13 additions and 4 deletions

View File

@ -247,6 +247,19 @@ class LoadImagesAndLabels(Dataset): # for training/testing
pass # print('Warning: missing labels for %s' % self.img_files[i]) # missing label file
assert len(np.concatenate(self.labels, 0)) > 0, 'No labels found. Incorrect label paths provided.'
# Cache images into memory for faster training (~5GB)
cache_images = False
if cache_images and augment: # if training
for i in tqdm(range(min(len(self.img_files), 10000)), desc='Caching images'): # max 10k images
img_path = self.img_files[i]
img = cv2.imread(img_path) # BGR
assert img is not None, 'Image Not Found ' + img_path
r = self.img_size / max(img.shape) # size ratio
if self.augment and r < 1: # if training (NOT testing), downsize to inference shape
h, w, _ = img.shape
img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) # or INTER_AREA
self.imgs[i] = img
# Detect corrupted images https://medium.com/joelthchao/programmatically-detect-corrupted-image-8c1b2006c3d3
detect_corrupted_images = False
if detect_corrupted_images:
@ -284,9 +297,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
h, w, _ = img.shape
img = cv2.resize(img, (int(w * r), int(h * r)), interpolation=cv2.INTER_LINEAR) # INTER_LINEAR fastest
if index < 5000: # cache first 5000 images into memory (~5GB)
self.imgs[index] = img
# Augment colorspace
augment_hsv = True
if self.augment and augment_hsv:
@ -320,7 +330,6 @@ class LoadImagesAndLabels(Dataset): # for training/testing
if x is None: # labels not preloaded
with open(label_path, 'r') as f:
x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
self.labels[index] = x # save for next time
if x.size > 0:
# Normalized xywh to pixel xyxy format