From bbe22dd7b4f6057a76291fb923abae052a8678c6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 23 Aug 2019 17:37:29 +0200 Subject: [PATCH] updates --- train.py | 3 ++- utils/utils.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 6cb2b135..e444f05a 100644 --- a/train.py +++ b/train.py @@ -191,6 +191,7 @@ def train(): # Start training model.nc = nc # attach number of classes to model + model.arc = opt.arc # attach yolo architecture model.hyp = hyp # attach hyperparameters to model model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights model_info(model, report='summary') # 'full' or 'summary' @@ -259,7 +260,7 @@ def train(): pred = model(imgs) # Compute loss - loss, loss_items = compute_loss(pred, targets, model, arc=opt.arc) + loss, loss_items = compute_loss(pred, targets, model) if torch.isnan(loss): print('WARNING: nan loss detected, ending training') return results diff --git a/utils/utils.py b/utils/utils.py index 0b28d7c9..5b307753 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -312,11 +312,12 @@ class FocalLoss(nn.Module): return loss -def compute_loss(p, targets, model, arc='default'): # predictions, targets, model +def compute_loss(p, targets, model): # predictions, targets, model ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) tcls, tbox, indices, anchor_vec = build_targets(model, targets) h = model.hyp # hyperparameters + arc = model.arc # # (default, uCE, uBCE) detection architectures # Define criteria BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) @@ -354,7 +355,7 @@ def compute_loss(p, targets, model, arc='default'): # predictions, targets, mod # with open('targets.txt', 'a') as file: # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] - if arc == 'default': # (default, uCE, uBCE) detection architectures + if arc == 'default': lobj += BCEobj(pi[..., 4], tobj) # obj loss elif arc == 'uCE': # unified CE (1 background + 80 classes), hyps 20