updates
This commit is contained in:
		
							parent
							
								
									5f2b551818
								
							
						
					
					
						commit
						2f256ee274
					
				|  | @ -321,8 +321,8 @@ def compute_loss(p, targets, model, arc='default'):  # predictions, targets, mod | |||
|     # Define criteria | ||||
|     BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) | ||||
|     BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']])) | ||||
|     BCE = nn.BCEWithLogitsLoss() | ||||
|     CE = nn.CrossEntropyLoss()  # weight=model.class_weights | ||||
|     FBCE = FocalLoss(nn.BCEWithLogitsLoss()) | ||||
|     FCE = FocalLoss(nn.CrossEntropyLoss())  # weight=model.class_weights | ||||
| 
 | ||||
|     # Compute losses | ||||
|     bs = p[0].shape[0]  # batch size | ||||
|  | @ -361,13 +361,13 @@ def compute_loss(p, targets, model, arc='default'):  # predictions, targets, mod | |||
|             t = torch.zeros_like(pi[..., 0], dtype=torch.long)  # targets | ||||
|             if nb: | ||||
|                 t[b, a, gj, gi] = tcls[i] + 1 | ||||
|             lcls += CE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1)) | ||||
|             lcls += FCE(pi[..., 4:].view(-1, model.nc + 1), t.view(-1)) | ||||
| 
 | ||||
|         elif arc == 'uBCE':  # unified BCE (1 background + 80 classes), hyps 200-30 | ||||
|             t = torch.zeros_like(pi[..., 5:])  # targets | ||||
|             if nb: | ||||
|                 t[b, a, gj, gi, tcls[i]] = 1.0 | ||||
|             lobj += BCE(pi[..., 5:], t) | ||||
|             lobj += FBCE(pi[..., 5:], t) | ||||
| 
 | ||||
|     lbox *= k * h['giou'] | ||||
|     lobj *= k * h['obj'] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue