This commit is contained in:
Glenn Jocher 2019-08-17 14:09:38 +02:00
parent a1200ef130
commit b8c870711f
1 changed files with 6 additions and 0 deletions

View File

@ -56,9 +56,15 @@ def model_info(model, report='summary'):
def labels_to_class_weights(labels, nc=80): def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels # Get class weights (inverse frequency) from training labels
ni = len(labels) # number of images
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
classes = labels[:, 0].astype(np.int) # labels = [class xywh] classes = labels[:, 0].astype(np.int) # labels = [class xywh]
weights = np.bincount(classes, minlength=nc) # occurences per class weights = np.bincount(classes, minlength=nc) # occurences per class
# Prepend gridpoint count (for uCE trianing)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * ni, weights]) # prepend gridpoints to start
weights[weights == 0] = 1 # replace empty bins with 1 weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class weights = 1 / weights # number of targets per class
weights /= weights.sum() # normalize weights /= weights.sum() # normalize