This commit is contained in:
Glenn Jocher 2019-04-27 18:15:27 +02:00
parent acaab77b7a
commit 1e3fb6566c
1 changed files with 4 additions and 1 deletions

View File

@ -53,7 +53,10 @@ def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels # Get class weights (inverse frequency) from training labels
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int) classes = labels[:, 0].astype(np.int)
weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class n = np.bincount(classes, minlength=nc)
weights = np.zeros(nc)
i = n.nonzero()
weights[i] = 1 / n[i] # number of targets per class
weights /= weights.sum() weights /= weights.sum()
return torch.Tensor(weights) return torch.Tensor(weights)