loss function cleanup
This commit is contained in:
		
							parent
							
								
									f1208f784e
								
							
						
					
					
						commit
						1a12667ce1
					
				|  | @ -377,21 +377,19 @@ def compute_loss(p, targets, model):  # predictions, targets, model | |||
|     lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) | ||||
|     tcls, tbox, indices, anchor_vec = build_targets(model, targets) | ||||
|     h = model.hyp  # hyperparameters | ||||
|     arc = model.arc  # # (default, uCE, uBCE) detection architectures | ||||
|     arc = model.arc  # architecture | ||||
|     red = 'mean'  # Loss reduction (sum or mean) | ||||
| 
 | ||||
|     # Define criteria | ||||
|     BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction=red) | ||||
|     BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction=red) | ||||
|     BCE = nn.BCEWithLogitsLoss(reduction=red) | ||||
|     CE = nn.CrossEntropyLoss(reduction=red)  # weight=model.class_weights | ||||
| 
 | ||||
|     # class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 | ||||
|     cp, cn = smooth_BCE(eps=0.0) | ||||
| 
 | ||||
|     if 'F' in arc:  # add focal loss | ||||
|         g = h['fl_gamma'] | ||||
|         BCEcls, BCEobj, BCE, CE = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g), FocalLoss(BCE, g), FocalLoss(CE, g) | ||||
|     # focal loss | ||||
|     if 'F' in arc: | ||||
|         BCEcls, BCEobj = FocalLoss(BCEcls, h['fl_gamma']), FocalLoss(BCEobj, h['fl_gamma']) | ||||
| 
 | ||||
|     # Compute losses | ||||
|     np, ng = 0, 0  # number grid points, targets | ||||
|  | @ -415,8 +413,8 @@ def compute_loss(p, targets, model):  # predictions, targets, model | |||
|             lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss | ||||
|             tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().type(tobj.dtype)  # giou ratio | ||||
| 
 | ||||
|             if 'default' in arc and model.nc > 1:  # cls loss (only if multiple classes) | ||||
|                 t = torch.zeros_like(ps[:, 5:]) + cn  # targets | ||||
|             if model.nc > 1:  # cls loss (only if multiple classes) | ||||
|                 t = torch.full_like(ps[:, 5:], cn)  # targets | ||||
|                 t[range(nb), tcls[i]] = cp | ||||
|                 lcls += BCEcls(ps[:, 5:], t)  # BCE | ||||
|                 # lcls += CE(ps[:, 5:], tcls[i])  # CE | ||||
|  | @ -425,21 +423,8 @@ def compute_loss(p, targets, model):  # predictions, targets, model | |||
|             # with open('targets.txt', 'a') as file: | ||||
|             #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] | ||||
| 
 | ||||
|         if 'default' in arc:  # separate obj and cls | ||||
|         lobj += BCEobj(pi[..., 4], tobj)  # obj loss | ||||
| 
 | ||||
|         elif 'BCE' in arc:  # unified BCE (80 classes) | ||||
|             t = torch.zeros_like(pi[..., 5:])  # targets | ||||
|             if nb: | ||||
|                 t[b, a, gj, gi, tcls[i]] = 1.0 | ||||
|             lobj += BCE(pi[..., 5:], t) | ||||
| 
 | ||||
|         elif 'CE' in arc:  # unified CE (1 background + 80 classes) | ||||
|             t = torch.zeros_like(pi[..., 0], dtype=torch.long)  # targets | ||||
|             if nb: | ||||
|                 t[b, a, gj, gi] = tcls[i] + 1 | ||||
|             lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1)) | ||||
| 
 | ||||
|     lbox *= h['giou'] | ||||
|     lobj *= h['obj'] | ||||
|     lcls *= h['cls'] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue