updates
This commit is contained in:
parent
2300cb964a
commit
2391996474
|
@ -334,10 +334,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
tcls, tbox, indices, anchor_vec = build_targets(model, targets)
|
||||
h = model.hyp # hyperparameters
|
||||
arc = model.arc # # (default, uCE, uBCE) detection architectures
|
||||
red = 'mean' # Loss reduction (sum or mean)
|
||||
|
||||
# Define criteria
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]))
|
||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]))
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red)
|
||||
BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red)
|
||||
BCE = nn.BCEWithLogitsLoss()
|
||||
CE = nn.CrossEntropyLoss() # weight=model.class_weights
|
||||
|
||||
|
@ -346,13 +347,16 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g), FocalLoss(BCE, g), FocalLoss(CE, g)
|
||||
|
||||
# Compute losses
|
||||
np, ng = 0, 0 # number grid points, targets
|
||||
for i, pi in enumerate(p): # layer index, layer predictions
|
||||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||
tobj = torch.zeros_like(pi[..., 0]) # target obj
|
||||
np += tobj.numel()
|
||||
|
||||
# Compute losses
|
||||
nb = len(b)
|
||||
if nb: # number of targets
|
||||
ng += nb
|
||||
ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
|
||||
tobj[b, a, gj, gi] = 1.0 # obj
|
||||
# ps[:, 2:4] = torch.sigmoid(ps[:, 2:4]) # wh power loss (uncomment)
|
||||
|
@ -360,8 +364,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
# GIoU
|
||||
pxy = torch.sigmoid(ps[:, 0:2]) # pxy = pxy * s - (s - 1) / 2, s = 1.5 (scale_xy)
|
||||
pbox = torch.cat((pxy, torch.exp(ps[:, 2:4]).clamp(max=1E3) * anchor_vec[i]), 1) # predicted box
|
||||
giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||
lbox += (1.0 - giou).mean() # giou loss
|
||||
giou = 1.0 - bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation
|
||||
lbox += giou.sum() if red == 'sum' else giou.mean() # giou loss
|
||||
|
||||
if 'default' in arc and model.nc > 1: # cls loss (only if multiple classes)
|
||||
t = torch.zeros_like(ps[:, 5:]) # targets
|
||||
|
@ -396,6 +400,11 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
lbox *= h['giou']
|
||||
lobj *= h['obj']
|
||||
lcls *= h['cls']
|
||||
if red == 'sum':
|
||||
lbox *= 3 / ng
|
||||
lobj *= 3 / np
|
||||
lcls *= 3 / ng / model.nc
|
||||
|
||||
loss = lbox + lobj + lcls
|
||||
return loss, torch.cat((lbox, lobj, lcls, loss)).detach()
|
||||
|
||||
|
|
Loading…
Reference in New Issue