From 8e327e3bd08bf1bf60925d936b2e3e874e0db435 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 20 Nov 2019 13:33:25 -0800 Subject: [PATCH] updates --- train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 1b1975e0..6987bdc6 100644 --- a/train.py +++ b/train.py @@ -204,7 +204,8 @@ def train(): model.nc = nc # attach number of classes to model model.arc = opt.arc # attach yolo architecture model.hyp = hyp # attach hyperparameters to model - model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights + if hasattr(dataset, 'labels'): + model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights torch_utils.model_info(model, report='summary') # 'full' or 'summary' nb = len(dataloader) maps = np.zeros(nc) # mAP per class