loss function cleanup
This commit is contained in:
parent
f1208f784e
commit
1a12667ce1
|
@ -377,21 +377,19 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||||
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||||
h = model.hyp # hyperparameters
|
h = model.hyp # hyperparameters
|
||||||
arc = model.arc # # (default, uCE, uBCE) detection architectures
|
arc = model.arc # architecture
|
||||||
red = 'mean' # Loss reduction (sum or mean)
|
red = 'mean' # Loss reduction (sum or mean)
|
||||||
|
|
||||||
# Define criteria
|
# Define criteria
|
||||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
|
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
|
||||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)
|
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)
|
||||||
BCE = nn.BCEWithLogitsLoss(reduction=red)
|
|
||||||
CE = nn.CrossEntropyLoss(reduction=red) # weight=model.class_weights
|
|
||||||
|
|
||||||
# class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
# class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||||
cp, cn = smooth_BCE(eps=0.0)
|
cp, cn = smooth_BCE(eps=0.0)
|
||||||
|
|
||||||
if 'F' in arc: # add focal loss
|
# focal loss
|
||||||
g = h['fl_gamma']
|
if 'F' in arc:
|
||||||
BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g), FocalLoss(BCE, g), FocalLoss(CE, g)
|
BCEcls, BCEobj = FocalLoss(BCEcls, h['fl_gamma']), FocalLoss(BCEobj, h['fl_gamma'])
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
np, ng = 0, 0 # number grid points, targets
|
np, ng = 0, 0 # number grid points, targets
|
||||||
|
@ -415,8 +413,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss
|
lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss
|
||||||
tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().type(tobj.dtype) # giou ratio
|
tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().type(tobj.dtype) # giou ratio
|
||||||
|
|
||||||
if 'default' in arc and model.nc > 1: # cls loss (only if multiple classes)
|
if model.nc > 1: # cls loss (only if multiple classes)
|
||||||
t = torch.zeros_like(ps[:, 5:]) + cn # targets
|
t = torch.full_like(ps[:, 5:], cn) # targets
|
||||||
t[range(nb), tcls[i]] = cp
|
t[range(nb), tcls[i]] = cp
|
||||||
lcls += BCEcls(ps[:, 5:], t) # BCE
|
lcls += BCEcls(ps[:, 5:], t) # BCE
|
||||||
# lcls += CE(ps[:, 5:], tcls[i]) # CE
|
# lcls += CE(ps[:, 5:], tcls[i]) # CE
|
||||||
|
@ -425,20 +423,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
# with open('targets.txt', 'a') as file:
|
# with open('targets.txt', 'a') as file:
|
||||||
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
||||||
|
|
||||||
if 'default' in arc: # separate obj and cls
|
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
||||||
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
|
||||||
|
|
||||||
elif 'BCE' in arc: # unified BCE (80 classes)
|
|
||||||
t = torch.zeros_like(pi[..., 5:]) # targets
|
|
||||||
if nb:
|
|
||||||
t[b, a, gj, gi, tcls[i]] = 1.0
|
|
||||||
lobj += BCE(pi[..., 5:], t)
|
|
||||||
|
|
||||||
elif 'CE' in arc: # unified CE (1 background + 80 classes)
|
|
||||||
t = torch.zeros_like(pi[..., 0], dtype=torch.long) # targets
|
|
||||||
if nb:
|
|
||||||
t[b, a, gj, gi] = tcls[i] + 1
|
|
||||||
lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1))
|
|
||||||
|
|
||||||
lbox *= h['giou']
|
lbox *= h['giou']
|
||||||
lobj *= h['obj']
|
lobj *= h['obj']
|
||||||
|
|
Loading…
Reference in New Issue