updates
This commit is contained in:
		
							parent
							
								
									bd9789aa00
								
							
						
					
					
						commit
						c77b87489c
					
				|  | @ -281,7 +281,7 @@ def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, mo | ||||||
|     MSE = nn.MSELoss() |     MSE = nn.MSELoss() | ||||||
|     BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) |     BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) | ||||||
|     BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']])) |     BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']])) | ||||||
|     CE = nn.CrossEntropyLoss()  # (weight=model.class_weights) |     # CE = nn.CrossEntropyLoss()  # (weight=model.class_weights) | ||||||
| 
 | 
 | ||||||
|     # Compute losses |     # Compute losses | ||||||
|     bs = p[0].shape[0]  # batch size |     bs = p[0].shape[0]  # batch size | ||||||
|  | @ -291,7 +291,8 @@ def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, mo | ||||||
|         tobj = torch.zeros_like(pi0[..., 0])  # target obj |         tobj = torch.zeros_like(pi0[..., 0])  # target obj | ||||||
| 
 | 
 | ||||||
|         # Compute losses |         # Compute losses | ||||||
|         if len(b):  # number of targets |         nb = len(b) | ||||||
|  |         if nb:  # number of targets | ||||||
|             pi = pi0[b, a, gj, gi]  # predictions closest to anchors |             pi = pi0[b, a, gj, gi]  # predictions closest to anchors | ||||||
|             tobj[b, a, gj, gi] = 1.0  # obj |             tobj[b, a, gj, gi] = 1.0  # obj | ||||||
|             # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4])  # wh power loss (uncomment) |             # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4])  # wh power loss (uncomment) | ||||||
|  | @ -304,10 +305,10 @@ def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, mo | ||||||
|                 lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss |                 lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss | ||||||
|                 lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss |                 lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss | ||||||
| 
 | 
 | ||||||
|             # tclsm = torch.zeros_like(pi[..., 5:]) |             tclsm = torch.zeros_like(pi[..., 5:]) | ||||||
|             # tclsm[range(len(b)), tcls[i]] = 1.0 |             tclsm[range(nb), tcls[i]] = 1.0 | ||||||
|             # lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm)  # cls loss (BCE) |             lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm)  # cls loss (BCE) | ||||||
|             lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # cls loss (CE) |             # lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # cls loss (CE) | ||||||
| 
 | 
 | ||||||
|             # Append targets to text file |             # Append targets to text file | ||||||
|             # with open('targets.txt', 'a') as file: |             # with open('targets.txt', 'a') as file: | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue