This commit is contained in:
Glenn Jocher 2019-04-27 17:49:22 +02:00
parent 55077b2770
commit 8f1becd55c
1 changed files with 1 additions and 1 deletions

View File

@ -53,7 +53,7 @@ def labels_to_class_weights(labels):
# Get class weights (inverse frequency) from training labels # Get class weights (inverse frequency) from training labels
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int) classes = labels[:, 0].astype(np.int)
weights = 1 / (np.bincount(classes) + 1e-6) # number of targets per class weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
weights /= weights.sum() weights /= weights.sum()
return torch.Tensor(weights) return torch.Tensor(weights)