From bb209111c44ee7390d57220d4980f62162637f13 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 20 Nov 2019 13:36:15 -0800 Subject: [PATCH] updates --- train.py | 3 +-- utils/utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 6987bdc6..1b1975e0 100644 --- a/train.py +++ b/train.py @@ -204,8 +204,7 @@ def train(): model.nc = nc # attach number of classes to model model.arc = opt.arc # attach yolo architecture model.hyp = hyp # attach hyperparameters to model - if hasattr(dataset, 'labels'): - model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights + model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights torch_utils.model_info(model, report='summary') # 'full' or 'summary' nb = len(dataloader) maps = np.zeros(nc) # mAP per class diff --git a/utils/utils.py b/utils/utils.py index 503cdff2..b7a9d8c5 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -43,14 +43,16 @@ def load_classes(path): def labels_to_class_weights(labels, nc=80): # Get class weights (inverse frequency) from training labels - ni = len(labels) # number of images + if labels[0] is None: # no labels loaded + return None + labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO classes = labels[:, 0].astype(np.int) # labels = [class xywh] weights = np.bincount(classes, minlength=nc) # occurences per class # Prepend gridpoint count (for uCE trianing) # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image - # weights = np.hstack([gpi * ni - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start + # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start weights[weights == 0] = 1 # replace empty bins with 1 weights = 1 / weights # number of targets per class