updates
This commit is contained in:
		
							parent
							
								
									100f443722
								
							
						
					
					
						commit
						a70c9f87a9
					
				|  | @ -251,26 +251,26 @@ def compute_loss(p, targets):  # predictions, targets | |||
|     BCE = nn.BCEWithLogitsLoss() | ||||
| 
 | ||||
|     # Compute losses | ||||
|     # bs = p[0].shape[0]  # batch size | ||||
|     bs = p[0].shape[0]  # batch size | ||||
|     # gp = [x.numel() for x in tconf]  # grid points | ||||
|     for i, pi0 in enumerate(p):  # layer i predictions, i | ||||
|         b, a, gj, gi = indices[i]  # image, anchor, gridx, gridy | ||||
|         tconf = torch.zeros_like(pi0[..., 0])  # conf | ||||
| 
 | ||||
|         # Compute losses | ||||
|         k = 135.8 | ||||
|         k = 8.4875 * bs | ||||
|         if len(b):  # number of targets | ||||
|             pi = pi0[b, a, gj, gi]  # predictions closest to anchors | ||||
|             tconf[b, a, gj, gi] = 1  # conf | ||||
| 
 | ||||
|             lxy += (k * 0.07997) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss | ||||
|             lwh += (k * 0.007867) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss | ||||
|             # lwh += (k * 0.007867) * MSE(torch.sigmoid(pi[..., 2:4]), twh[i])  # wh power loss | ||||
|             lcls += (k * 0.02111) * CE(pi[..., 5:], tcls[i])  # class_conf loss | ||||
|             lxy += (k * 0.079756) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss | ||||
|             lwh += (k * 0.010461) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss | ||||
|             # lwh += (k * 0.010461) * MSE(torch.sigmoid(pi[..., 2:4]), twh[i])  # wh power loss | ||||
|             lcls += (k * 0.02105) * CE(pi[..., 5:], tcls[i])  # class_conf loss | ||||
| 
 | ||||
|         # pos_weight = ft([gp[i] / min(gp) * 4.]) | ||||
|         # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight) | ||||
|         lconf += (k * 0.8911) * BCE(pi0[..., 4], tconf)  # obj_conf loss | ||||
|         lconf += (k * 0.88873) * BCE(pi0[..., 4], tconf)  # obj_conf loss | ||||
|     loss = lxy + lwh + lconf + lcls | ||||
| 
 | ||||
|     return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue