updates
This commit is contained in:
parent
55077b2770
commit
8f1becd55c
|
@ -53,7 +53,7 @@ def labels_to_class_weights(labels):
|
|||
# Get class weights (inverse frequency) from training labels
|
||||
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
||||
classes = labels[:, 0].astype(np.int)
|
||||
weights = 1 / (np.bincount(classes) + 1e-6) # number of targets per class
|
||||
weights = 1 / (np.bincount(classes, minlength=classes.max() + 1) + 1e-6) # number of targets per class
|
||||
weights /= weights.sum()
|
||||
return torch.Tensor(weights)
|
||||
|
||||
|
|
Loading…
Reference in New Issue