This commit is contained in:
Glenn Jocher 2019-02-11 12:40:14 +01:00
parent 003daea143
commit ebd682b25c
1 changed files with 7 additions and 8 deletions

View File

@ -24,7 +24,7 @@ class load_images(): # for inference
self.nF = len(self.files) # number of image files self.nF = len(self.files) # number of image files
self.height = img_size self.height = img_size
assert self.nF > 0, 'No images found in path %s' % path assert self.nF > 0, 'No images found in ' + path
def __iter__(self): def __iter__(self):
self.count = -1 self.count = -1
@ -41,7 +41,7 @@ class load_images(): # for inference
assert img0 is not None, 'Failed to load ' + img_path assert img0 is not None, 'Failed to load ' + img_path
# Padded resize # Padded resize
img, ratio, padw, padh = letterbox(img0, height=self.height, color=(127.5, 127.5, 127.5)) img, _, _, _ = letterbox(img0, height=self.height)
# Normalize RGB # Normalize RGB
img = img[:, :, ::-1].transpose(2, 0, 1) img = img[:, :, ::-1].transpose(2, 0, 1)
@ -58,13 +58,12 @@ class load_images(): # for inference
class load_images_and_labels(): # for training class load_images_and_labels(): # for training
def __init__(self, path, batch_size=1, img_size=608, multi_scale=False, augment=False): def __init__(self, path, batch_size=1, img_size=608, multi_scale=False, augment=False):
self.path = path self.path = path
# self.img_files = sorted(glob.glob('%s/*.*' % path))
with open(path, 'r') as file: with open(path, 'r') as file:
self.img_files = file.readlines() self.img_files = file.readlines()
self.img_files = [path.replace('\n', '') for path in self.img_files] self.img_files = [path.replace('\n', '') for path in self.img_files]
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
self.img_files] for path in self.img_files]
self.nF = len(self.img_files) # number of image files self.nF = len(self.img_files) # number of image files
self.nB = math.ceil(self.nF / batch_size) # number of batches self.nB = math.ceil(self.nF / batch_size) # number of batches
@ -73,7 +72,7 @@ class load_images_and_labels(): # for training
self.multi_scale = multi_scale self.multi_scale = multi_scale
self.augment = augment self.augment = augment
assert self.nB > 0, 'No images found in path %s' % path assert self.nB > 0, 'No images found in %s' % path
def __iter__(self): def __iter__(self):
self.count = -1 self.count = -1
@ -128,7 +127,7 @@ class load_images_and_labels(): # for training
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
h, w, _ = img.shape h, w, _ = img.shape
img, ratio, padw, padh = letterbox(img, height=height, color=(127.5, 127.5, 127.5)) img, ratio, padw, padh = letterbox(img, height=height)
# Load labels # Load labels
if os.path.isfile(label_path): if os.path.isfile(label_path):
@ -189,7 +188,7 @@ class load_images_and_labels(): # for training
return self.nB # number of batches return self.nB # number of batches
def letterbox(img, height=416, color=(0, 0, 0)): # resize a rectangular image to a padded square def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded square
shape = img.shape[:2] # shape = [height, width] shape = img.shape[:2] # shape = [height, width]
ratio = float(height) / max(shape) # ratio = old / new ratio = float(height) / max(shape) # ratio = old / new
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio)) new_shape = (round(shape[1] * ratio), round(shape[0] * ratio))