From 2f256ee27466283224921ff83cdab9fd81795634 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 23 Aug 2019 17:24:50 +0200 Subject: [PATCH] updates --- utils/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 38ad3242..0b28d7c9 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -321,8 +321,8 @@ def compute_loss(p, targets, model, arc='default'): # predictions, targets, mod # Define criteria BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']])) - BCE = nn.BCEWithLogitsLoss() - CE = nn.CrossEntropyLoss() # weight=model.class_weights + FBCE = FocalLoss(nn.BCEWithLogitsLoss()) + FCE = FocalLoss(nn.CrossEntropyLoss()) # weight=model.class_weights # Compute losses bs = p[0].shape[0] # batch size @@ -361,13 +361,13 @@ def compute_loss(p, targets, model, arc='default'): # predictions, targets, mod t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets if nb: t[b, a, gj, gi] = tcls[i] + 1 - lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1)) + lcls += FCE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1)) elif arc == 'uBCE': # unified BCE (1 background + 80 classes), hyps 200-30 t = torch.zeros_like(pi[..., 5:]) # targets if nb: t[b, a, gj, gi, tcls[i]] = 1.0 - lobj += BCE(pi[..., 5:], t) + lobj += FBCE(pi[..., 5:], t) lbox *= k * h['giou'] lobj *= k * h['obj']