diff --git a/train.py b/train.py index f30ba066..1b34f62c 100644 --- a/train.py +++ b/train.py @@ -145,6 +145,7 @@ def train( # Start training t, t0 = time.time(), time.time() model.hyp = hyp # attach hyperparameters to model + model.class_weights = labels_to_class_weights(dataset.labels).to(device) # attach class weights model_info(model) nb = len(dataloader) results = (0, 0, 0, 0, 0) # P, R, mAP, F1, test_loss diff --git a/utils/datasets.py b/utils/datasets.py index 4c949c78..2ccd5d43 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -174,9 +174,20 @@ class LoadImagesAndLabels(Dataset): # for training/testing self.batch_shapes = np.ceil(np.array(shapes) * img_size / 32.).astype(np.int) * 32 self.batch = bi # batch index of image + # Preload images # if n < 200: # preload all images into memory if possible # self.imgs = [cv2.imread(img_files[i]) for i in range(n)] + # Preload labels (required for weighted CE training) + self.labels = [np.array([])] * n + iter = tqdm(self.label_files, desc='Reading labels') if n > 5000 else self.label_files + for i, file in enumerate(iter): + try: + with open(file, 'r') as f: + self.labels[i] = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) + except: # missing label file + pass + def __len__(self): return len(self.img_files) @@ -217,10 +228,9 @@ class LoadImagesAndLabels(Dataset): # for training/testing # Load labels labels = [] if os.path.isfile(label_path): - with open(label_path, 'r') as file: - lines = file.read().splitlines() - - x = np.array([x.split() for x in lines], dtype=np.float32) + # with open(label_path, 'r') as f: + # x = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32) + x = self.labels[index] if x.size > 0: # Normalized xywh to pixel xyxy format labels = x.copy() diff --git a/utils/utils.py b/utils/utils.py index 3e66439f..b961cbe6 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -49,6 +49,15 @@ def model_info(model): print('Model Summary: %g layers, %g parameters, %g gradients' % (i + 1, n_p, n_g)) +def labels_to_class_weights(labels): + # Get class weights (inverse frequency) from training labels + labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO + classes = labels[:, 0].astype(np.int) + weights = 1 / (np.bincount(classes) + 1e-6) # number of targets per class + weights /= weights.sum() + return torch.Tensor(weights) + + def coco_class_weights(): # frequency of each class in coco train2014 weights = 1 / torch.FloatTensor( [187437, 4955, 30920, 6033, 3838, 4332, 3160, 7051, 7677, 9167, 1316, 1372, 833, 6757, 7355, 3302, 3776, 4671, @@ -247,7 +256,7 @@ def compute_loss(p, targets, model): # predictions, targets, model # Define criteria MSE = nn.MSELoss() - CE = nn.CrossEntropyLoss() + CE = nn.CrossEntropyLoss(weight=model.class_weights) BCE = nn.BCEWithLogitsLoss() # Compute losses