updates
This commit is contained in:
		
							parent
							
								
									100f443722
								
							
						
					
					
						commit
						a70c9f87a9
					
				|  | @ -251,26 +251,26 @@ def compute_loss(p, targets):  # predictions, targets | ||||||
|     BCE = nn.BCEWithLogitsLoss() |     BCE = nn.BCEWithLogitsLoss() | ||||||
| 
 | 
 | ||||||
|     # Compute losses |     # Compute losses | ||||||
|     # bs = p[0].shape[0]  # batch size |     bs = p[0].shape[0]  # batch size | ||||||
|     # gp = [x.numel() for x in tconf]  # grid points |     # gp = [x.numel() for x in tconf]  # grid points | ||||||
|     for i, pi0 in enumerate(p):  # layer i predictions, i |     for i, pi0 in enumerate(p):  # layer i predictions, i | ||||||
|         b, a, gj, gi = indices[i]  # image, anchor, gridx, gridy |         b, a, gj, gi = indices[i]  # image, anchor, gridx, gridy | ||||||
|         tconf = torch.zeros_like(pi0[..., 0])  # conf |         tconf = torch.zeros_like(pi0[..., 0])  # conf | ||||||
| 
 | 
 | ||||||
|         # Compute losses |         # Compute losses | ||||||
|         k = 135.8 |         k = 8.4875 * bs | ||||||
|         if len(b):  # number of targets |         if len(b):  # number of targets | ||||||
|             pi = pi0[b, a, gj, gi]  # predictions closest to anchors |             pi = pi0[b, a, gj, gi]  # predictions closest to anchors | ||||||
|             tconf[b, a, gj, gi] = 1  # conf |             tconf[b, a, gj, gi] = 1  # conf | ||||||
| 
 | 
 | ||||||
|             lxy += (k * 0.07997) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss |             lxy += (k * 0.079756) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss | ||||||
|             lwh += (k * 0.007867) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss |             lwh += (k * 0.010461) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss | ||||||
|             # lwh += (k * 0.007867) * MSE(torch.sigmoid(pi[..., 2:4]), twh[i])  # wh power loss |             # lwh += (k * 0.010461) * MSE(torch.sigmoid(pi[..., 2:4]), twh[i])  # wh power loss | ||||||
|             lcls += (k * 0.02111) * CE(pi[..., 5:], tcls[i])  # class_conf loss |             lcls += (k * 0.02105) * CE(pi[..., 5:], tcls[i])  # class_conf loss | ||||||
| 
 | 
 | ||||||
|         # pos_weight = ft([gp[i] / min(gp) * 4.]) |         # pos_weight = ft([gp[i] / min(gp) * 4.]) | ||||||
|         # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight) |         # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight) | ||||||
|         lconf += (k * 0.8911) * BCE(pi0[..., 4], tconf)  # obj_conf loss |         lconf += (k * 0.88873) * BCE(pi0[..., 4], tconf)  # obj_conf loss | ||||||
|     loss = lxy + lwh + lconf + lcls |     loss = lxy + lwh + lconf + lcls | ||||||
| 
 | 
 | ||||||
|     return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach() |     return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach() | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue