diff --git a/utils/utils.py b/utils/utils.py index a27cad38..59fe1a8c 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -53,7 +53,10 @@ def labels_to_class_weights(labels, nc=80): # Get class weights (inverse frequency) from training labels labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO classes = labels[:, 0].astype(np.int) - weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class + n = np.bincount(classes, minlength=nc) + weights = np.zeros(nc) + i = n.nonzero() + weights[i] = 1 / n[i] # number of targets per class weights /= weights.sum() return torch.Tensor(weights)