updates
This commit is contained in:
parent
55077b2770
commit
8f1becd55c
|
@ -53,7 +53,7 @@ def labels_to_class_weights(labels):
|
||||||
# Get class weights (inverse frequency) from training labels
|
# Get class weights (inverse frequency) from training labels
|
||||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||||
classes = labels[:, 0].astype(np.int)
|
classes = labels[:, 0].astype(np.int)
|
||||||
weights = 1 / (np.bincount(classes) + 1e-6) # number of targets per class
|
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
|
||||||
weights /= weights.sum()
|
weights /= weights.sum()
|
||||||
return torch.Tensor(weights)
|
return torch.Tensor(weights)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue