diff --git a/test.py b/test.py
index 7d9f8822..ec9d717d 100644
--- a/test.py
+++ b/test.py
@@ -44,7 +44,7 @@ def test(
     names = load_classes(data_cfg['names'])  # class names

     # Dataloader
-    dataset = LoadImagesAndLabels(test_path, img_size=img_size)
+    dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
     dataloader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=4,
diff --git a/train.py b/train.py
index af10b5f8..3a2bae2c 100644
--- a/train.py
+++ b/train.py
@@ -119,7 +119,7 @@ def train(
     #     plt.savefig('LR.png', dpi=300)

     # Dataset
-    dataset = LoadImagesAndLabels(train_path, img_size=img_size, augment=True)
+    dataset = LoadImagesAndLabels(train_path, img_size, batch_size, augment=True)

     # Initialize distributed training
     if torch.cuda.device_count() > 1:
@@ -131,7 +131,7 @@ def train(
     dataloader = DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=opt.num_workers,
-                            shuffle=True,
+                            shuffle=False,  # keep aspect-ratio-sorted batches intact
                             pin_memory=True,
                             collate_fn=dataset.collate_fn)

diff --git a/utils/datasets.py b/utils/datasets.py
index 1dc1b7b8..5fa6e182 100755
--- a/utils/datasets.py
+++ b/utils/datasets.py
@@ -74,7 +74,7 @@ class LoadImages:  # for inference
         print('image %g/%g %s: ' % (self.count, self.nF, path), end='')

         # Padded resize
-        img, _, _, _ = letterbox(img0, height=self.height)
+        img, _, _, _ = letterbox(img0, new_shape=self.height)
         print('%gx%g ' % img.shape[:2], end='')  # print image size

         # Normalize RGB
@@ -116,7 +116,7 @@ class LoadWebcam:  # for inference
             img0 = cv2.flip(img0, 1)  # flip left-right

         # Padded resize
-        img, _, _, _ = letterbox(img0, height=self.height)
+        img, _, _, _ = letterbox(img0, new_shape=self.height)

         # Normalize RGB
         img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB
@@ -130,7 +130,7 @@ class LoadWebcam:  # for inference


 class LoadImagesAndLabels(Dataset):  # for training/testing
-    def __init__(self, path, img_size=416, augment=False):
+    def __init__(self, path, img_size=416, batch_size=16, augment=False):
         with open(path, 'r') as f:
             img_files = f.read().splitlines()
             self.img_files = list(filter(lambda x: len(x) > 0, img_files))
@@ -143,17 +143,35 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
             x.replace('images', 'labels').replace('.bmp', '.txt').replace('.jpg', '.txt').replace('.png', '.txt')
             for x in self.img_files]

-        # sort dataset by aspect ratio for rectangular training
-        self.rectangle = False
-        if self.rectangle:
+        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
+        self.train_rectangular = True
+        if self.train_rectangular:
+            bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index of each image
+            nb = bi[-1] + 1  # number of batches
             from PIL import Image
+            # Read image aspect ratios
             s = np.array([Image.open(f).size for f in tqdm(self.img_files, desc='Reading image shapes')])
             ar = s[:, 1] / s[:, 0]  # aspect ratio
+
+            # Sort by aspect ratio
             i = ar.argsort()
+            ar = ar[i]
             self.img_files = [self.img_files[i] for i in i]
             self.label_files = [self.label_files[i] for i in i]
-            self.ar = ar[i]
+
+            # Set training image shapes
+            shapes = [[1, 1]] * nb
+            for i in range(nb):
+                ari = ar[bi == i]
+                mini, maxi = ari.min(), ari.max()
+                if maxi < 1:
+                    shapes[i] = [maxi, 1]
+                elif mini > 1:
+                    shapes[i] = [1, 1 / mini]
+
+            self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32
+            self.batch = bi  # batch index of image

         # if n < 200:  # preload all images into memory if possible
         #    self.imgs = [cv2.imread(img_files[i]) for i in range(n)]

@@ -187,8 +205,13 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
                 img_hsv[:, :, 2] = V if b < 1 else V.clip(None, 255)
                 cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)

+        # Letterbox
         h, w, _ = img.shape
-        img, ratio, padw, padh = letterbox(img, height=self.img_size, mode='square')
+        if self.train_rectangular:
+            new_shape = self.batch_shapes[self.batch[index]]
+            img, ratio, padw, padh = letterbox(img, new_shape=new_shape, mode='rect')
+        else:
+            img, ratio, padw, padh = letterbox(img, new_shape=self.img_size, mode='square')

         # Load labels
         labels = []
@@ -248,23 +271,30 @@ class LoadImagesAndLabels(Dataset):  # for training/testing
         return torch.stack(img, 0), torch.cat(label, 0), path, hw


-def letterbox(img, height=416, color=(127.5, 127.5, 127.5), mode='rect'):
+def letterbox(img, new_shape=416, color=(127.5, 127.5, 127.5), mode='auto'):
     # Resize a rectangular image to a 32 pixel multiple rectangle
-    shape = img.shape[:2]  # shape = [height, width]
-    ratio = float(height) / max(shape)  # ratio = old / new
-    new_shape = (round(shape[1] * ratio), round(shape[0] * ratio))  # new_shape = [width, height]
+    # https://github.com/ultralytics/yolov3/issues/232
+    shape = img.shape[:2]  # current shape [height, width]
+    if isinstance(new_shape, int):
+        ratio = float(new_shape) / max(shape)
+    else:
+        ratio = max(new_shape) / max(shape)  # ratio = new / old
+    new_unpad = (int(round(shape[1] * ratio)), int(round(shape[0] * ratio)))

-    # Select padding  https://github.com/ultralytics/yolov3/issues/232
-    if mode is 'rect':  # rectangle
-        dw = np.mod(height - new_shape[0], 32) / 2  # width padding
-        dh = np.mod(height - new_shape[1], 32) / 2  # height padding
+    # Compute padding  https://github.com/ultralytics/yolov3/issues/232
+    if mode is 'auto':  # minimum 32-multiple rectangle
+        dw = np.mod(new_shape - new_unpad[0], 32) / 2  # width padding
+        dh = np.mod(new_shape - new_unpad[1], 32) / 2  # height padding
     elif mode is 'square':  # square
-        dw = (height - new_shape[0]) / 2  # width padding
-        dh = (height - new_shape[1]) / 2  # height padding
+        dw = (new_shape - new_unpad[0]) / 2  # width padding
+        dh = (new_shape - new_unpad[1]) / 2  # height padding
+    elif mode is 'rect':  # exact target rectangle [height, width]
+        dw = (new_shape[1] - new_unpad[0]) / 2  # width padding
+        dh = (new_shape[0] - new_unpad[1]) / 2  # height padding

     top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
     left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
-    img = cv2.resize(img, new_shape, interpolation=cv2.INTER_AREA)  # resized, no border
+    img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_AREA)  # resized, no border
     img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # padded rectangle
     return img, ratio, dw, dh
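
A quick sanity check of the per-batch shape logic added above (not part of the patch; the aspect ratios are made-up toy values and the variable names simply mirror the diff):

# Standalone sketch of the rectangular-training shape computation.
# ar = height / width per image, already sorted ascending as in __init__.
import numpy as np

n, batch_size, img_size = 10, 4, 416
ar = np.array([0.5, 0.6, 0.7, 0.75, 1.0, 1.0, 1.2, 1.4, 1.6, 2.0])  # toy values
bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index of each image
nb = bi[-1] + 1  # number of batches

shapes = [[1, 1]] * nb
for i in range(nb):
    ari = ar[bi == i]
    mini, maxi = ari.min(), ari.max()
    if maxi < 1:  # whole batch is wide (h < w): shrink the height multiple
        shapes[i] = [maxi, 1]
    elif mini > 1:  # whole batch is tall (h > w): shrink the width multiple
        shapes[i] = [1, 1 / mini]

batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(int) * 32
print(batch_shapes)  # -> [[320 416] [416 416] [416 288]], one [height, width] row per batch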
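
Similarly, a minimal sketch of how the three letterbox() modes pad a 480x640 frame, assuming the patched utils/datasets.py is on the import path; the expected shapes in the comments follow from the padding math above:

# Sketch: padding behavior of the updated letterbox() for each mode.
import numpy as np
from utils.datasets import letterbox  # assumes the patched module is importable

img = np.zeros((480, 640, 3), dtype=np.uint8)  # height x width = 480 x 640

out, _, _, _ = letterbox(img, new_shape=416, mode='auto')
print(out.shape)  # (320, 416, 3): long side -> 416, short side padded to a 32-multiple

out, _, _, _ = letterbox(img, new_shape=416, mode='square')
print(out.shape)  # (416, 416, 3): legacy fixed-square padding

out, _, _, _ = letterbox(img, new_shape=[320, 416], mode='rect')
print(out.shape)  # (320, 416, 3): exact [height, width] target, as used with batch_shapes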