This commit is contained in:
Glenn Jocher 2019-11-20 12:05:40 -08:00
parent 429d44282c
commit e58f0a68b6
2 changed files with 3 additions and 3 deletions

View File

@ -204,7 +204,7 @@ def train():
model.nc = nc # attach number of classes to model
model.arc = opt.arc # attach yolo architecture
model.hyp = hyp # attach hyperparameters to model
# model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
torch_utils.model_info(model, report='summary') # 'full' or 'summary'
nb = len(dataloader)
maps = np.zeros(nc) # mAP per class

View File

@ -49,8 +49,8 @@ def labels_to_class_weights(labels, nc=80):
weights = np.bincount(classes, minlength=nc) # occurences per class
# Prepend gridpoint count (for uCE trianing)
gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
weights = np.hstack([gpi * ni - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * ni - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
weights[weights == 0] = 1 # replace empty bins with 1
weights = 1 / weights # number of targets per class