From 1a12667ce1f0af92b68987e4b48b060a6f738248 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 16 Mar 2020 17:31:37 -0700 Subject: [PATCH] loss function cleanup --- utils/utils.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 82b75d41..7bb8babb 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -377,21 +377,19 @@ def compute_loss(p, targets, model): # predictions, targets, model lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) tcls, tbox, indices, anchor_vec = build_targets(model, targets) h = model.hyp # hyperparameters - arc = model.arc # # (default, uCE, uBCE) detection architectures + arc = model.arc # architecture red = 'mean' # Loss reduction (sum or mean) # Define criteria BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red) BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red) - BCE = nn.BCEWithLogitsLoss(reduction=red) - CE = nn.CrossEntropyLoss(reduction=red) # weight=model.class_weights # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 cp, cn = smooth_BCE(eps=0.0) - if 'F' in arc: # add focal loss - g = h['fl_gamma'] - BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g), FocalLoss(BCE, g), FocalLoss(CE, g) + # focal loss + if 'F' in arc: + BCEcls, BCEobj = FocalLoss(BCEcls, h['fl_gamma']), FocalLoss(BCEobj, h['fl_gamma']) # Compute losses np, ng = 0, 0 # number grid points, targets @@ -415,8 +413,8 @@ def compute_loss(p, targets, model): # predictions, targets, model lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().type(tobj.dtype) # giou ratio - if 'default' in arc and model.nc > 1: # cls loss (only if multiple classes) - t = torch.zeros_like(ps[:, 5:]) + cn # targets + if model.nc > 1: # cls loss (only if multiple classes) + t = torch.full_like(ps[:, 5:], cn) # targets t[range(nb), tcls[i]] = cp lcls += BCEcls(ps[:, 5:], t) # BCE # lcls += CE(ps[:, 5:], tcls[i]) # CE @@ -425,20 +423,7 @@ def compute_loss(p, targets, model): # predictions, targets, model # with open('targets.txt', 'a') as file: # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] - if 'default' in arc: # separate obj and cls - lobj += BCEobj(pi[..., 4], tobj) # obj loss - - elif 'BCE' in arc: # unified BCE (80 classes) - t = torch.zeros_like(pi[..., 5:]) # targets - if nb: - t[b, a, gj, gi, tcls[i]] = 1.0 - lobj += BCE(pi[..., 5:], t) - - elif 'CE' in arc: # unified CE (1 background + 80 classes) - t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets - if nb: - t[b, a, gj, gi] = tcls[i] + 1 - lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1)) + lobj += BCEobj(pi[..., 4], tobj) # obj loss lbox *= h['giou'] lobj *= h['obj']