updates
This commit is contained in:
parent
acaab77b7a
commit
1e3fb6566c
|
@ -53,7 +53,10 @@ def labels_to_class_weights(labels, nc=80):
|
||||||
# Get class weights (inverse frequency) from training labels
|
# Get class weights (inverse frequency) from training labels
|
||||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||||
classes = labels[:, 0].astype(np.int)
|
classes = labels[:, 0].astype(np.int)
|
||||||
weights = 1 / (np.bincount(classes, minlength=nc) + 1e-6) # number of targets per class
|
n = np.bincount(classes, minlength=nc)
|
||||||
|
weights = np.zeros(nc)
|
||||||
|
i = n.nonzero()
|
||||||
|
weights[i] = 1 / n[i] # number of targets per class
|
||||||
weights /= weights.sum()
|
weights /= weights.sum()
|
||||||
return torch.Tensor(weights)
|
return torch.Tensor(weights)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue