This commit is contained in:
Glenn Jocher 2019-05-05 13:21:37 +02:00
parent dd2d713484
commit 7d857cda95
1 changed files with 7 additions and 7 deletions

View File

@ -130,7 +130,7 @@ class LoadWebcam: # for inference
class LoadImagesAndLabels(Dataset): # for training/testing
def __init__(self, path, img_size=416, batch_size=16, augment=False):
def __init__(self, path, img_size=416, batch_size=16, augment=False, rect=True):
with open(path, 'r') as f:
img_files = f.read().splitlines()
self.img_files = list(filter(lambda x: len(x) > 0, img_files))
@ -144,8 +144,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
for x in self.img_files]
# Rectangular Training https://github.com/ultralytics/yolov3/issues/232
self.train_rectangular = False
if self.train_rectangular:
self.pad_rectangular = rect
if self.pad_rectangular:
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches
from PIL import Image
@ -185,8 +185,8 @@ class LoadImagesAndLabels(Dataset): # for training/testing
try:
with open(file, 'r') as f:
self.labels[i] = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
except: # missing label file
pass
except:
pass # missing label file
def __len__(self):
return len(self.img_files)
@ -195,7 +195,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
img_path = self.img_files[index]
label_path = self.label_files[index]
# if hasattr(self, 'imgs'):
# if hasattr(self, 'imgs'): # preloaded
# img = self.imgs[index] # BGR
img = cv2.imread(img_path) # BGR
assert img is not None, 'File Not Found ' + img_path
@ -219,7 +219,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
# Letterbox
h, w, _ = img.shape
if self.train_rectangular:
if self.pad_rectangular:
new_shape = self.batch_shapes[self.batch[index]]
img, ratio, padw, padh = letterbox(img, new_shape=new_shape, mode='rect')
else: