This commit is contained in:
Glenn Jocher 2019-02-11 12:40:14 +01:00
parent 003daea143
commit ebd682b25c
1 changed files with 7 additions and 8 deletions

View File

@ -24,7 +24,7 @@ class load_images(): # for inference
self.nF = len(self.files) # number of image files
self.height = img_size
assert self.nF > 0, 'No images found in path %s' % path
assert self.nF > 0, 'No images found in ' + path
def __iter__(self):
self.count = -1
@ -41,7 +41,7 @@ class load_images(): # for inference
assert img0 is not None, 'Failed to load ' + img_path
# Padded resize
img, ratio, padw, padh = letterbox(img0, height=self.height, color=(127.5, 127.5, 127.5))
img, _, _, _ = letterbox(img0, height=self.height)
# Normalize RGB
img = img[:, :, ::-1].transpose(2, 0, 1)
@ -58,13 +58,12 @@ class load_images(): # for inference
class load_images_and_labels(): # for training
def __init__(self, path, batch_size=1, img_size=608, multi_scale=False, augment=False):
self.path = path
# self.img_files = sorted(glob.glob('%s/*.*' % path))
with open(path, 'r') as file:
self.img_files = file.readlines()
self.img_files = [path.replace('\n', '') for path in self.img_files]
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt') for path in
self.img_files]
self.label_files = [path.replace('images', 'labels').replace('.png', '.txt').replace('.jpg', '.txt')
for path in self.img_files]
self.nF = len(self.img_files) # number of image files
self.nB = math.ceil(self.nF / batch_size) # number of batches
@ -73,7 +72,7 @@ class load_images_and_labels(): # for training
self.multi_scale = multi_scale
self.augment = augment
assert self.nB > 0, 'No images found in path %s' % path
assert self.nB > 0, 'No images found in %s' % path
def __iter__(self):
self.count = -1
@ -128,7 +127,7 @@ class load_images_and_labels(): # for training
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
h, w, _ = img.shape
img, ratio, padw, padh = letterbox(img, height=height, color=(127.5, 127.5, 127.5))
img, ratio, padw, padh = letterbox(img, height=height)
# Load labels
if os.path.isfile(label_path):
@ -189,7 +188,7 @@ class load_images_and_labels(): # for training
return self.nB # number of batches
def letterbox(img, height=416, color=(0, 0, 0)): # resize a rectangular image to a padded square
def letterbox(img, height=416, color=(127.5, 127.5, 127.5)): # resize a rectangular image to a padded square
shape = img.shape[:2] # shape = [height, width]
ratio = float(height) / max(shape) # ratio = old / new
new_shape = (round(shape[1] * ratio), round(shape[0] * ratio))