equal layer weights
This commit is contained in:
		
							parent
							
								
									5886200401
								
							
						
					
					
						commit
						bd9789aa00
					
				
							
								
								
									
										4
									
								
								train.py
								
								
								
								
							
							
						
						
									
										4
									
								
								train.py
								
								
								
								
							|  | @ -348,14 +348,14 @@ if __name__ == '__main__': | ||||||
| 
 | 
 | ||||||
|             # Mutate |             # Mutate | ||||||
|             init_seeds(seed=int(time.time())) |             init_seeds(seed=int(time.time())) | ||||||
|             s = [.15, .15, .15, .15, .15, .15, .15, .15, 0, 0, 0, 0]  # fractional sigmas |             s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .15, .15, .15]  # fractional sigmas | ||||||
|             for i, k in enumerate(hyp.keys()): |             for i, k in enumerate(hyp.keys()): | ||||||
|                 x = (np.random.randn(1) * s[i] + 1) ** 2.0  # plt.hist(x.ravel(), 300) |                 x = (np.random.randn(1) * s[i] + 1) ** 2.0  # plt.hist(x.ravel(), 300) | ||||||
|                 hyp[k] *= float(x)  # vary by 20% 1sigma |                 hyp[k] *= float(x)  # vary by 20% 1sigma | ||||||
| 
 | 
 | ||||||
|             # Clip to limits |             # Clip to limits | ||||||
|             keys = ['lr0', 'iou_t', 'momentum', 'weight_decay'] |             keys = ['lr0', 'iou_t', 'momentum', 'weight_decay'] | ||||||
|             limits = [(1e-4, 1e-2), (0, 0.70), (0.70, 0.98), (0, 0.01)] |             limits = [(1e-4, 1e-2), (0.00, 0.70), (0.60, 0.98), (0, 0.01)] | ||||||
|             for k, v in zip(keys, limits): |             for k, v in zip(keys, limits): | ||||||
|                 hyp[k] = np.clip(hyp[k], v[0], v[1]) |                 hyp[k] = np.clip(hyp[k], v[0], v[1]) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -274,28 +274,24 @@ def wh_iou(box1, box2): | ||||||
| def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, model | def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, model | ||||||
|     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor |     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor | ||||||
|     lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0]) |     lxy, lwh, lcls, lobj = ft([0]), ft([0]), ft([0]), ft([0]) | ||||||
|     txy, twh, tcls, tbox, indices, anchor_vec, nc = build_targets(model, targets) |     txy, twh, tcls, tbox, indices, anchor_vec = build_targets(model, targets) | ||||||
|     h = model.hyp  # hyperparameters |     h = model.hyp  # hyperparameters | ||||||
| 
 | 
 | ||||||
|     # Define criteria |     # Define criteria | ||||||
|     MSE = nn.MSELoss(reduction='sum') |     MSE = nn.MSELoss() | ||||||
|     BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']]), reduction='sum') |     BCEcls = nn.BCEWithLogitsLoss(pos_weight=ft([h['cls_pw']])) | ||||||
|     BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']]), reduction='sum') |     BCEobj = nn.BCEWithLogitsLoss(pos_weight=ft([h['obj_pw']])) | ||||||
|     # CE = nn.CrossEntropyLoss()  # (weight=model.class_weights) |     CE = nn.CrossEntropyLoss()  # (weight=model.class_weights) | ||||||
| 
 | 
 | ||||||
|     # Compute losses |     # Compute losses | ||||||
|     bs = p[0].shape[0]  # batch size |     bs = p[0].shape[0]  # batch size | ||||||
|     k = 3 * bs / 64  # loss gain |     k = bs / 64  # loss gain | ||||||
|     nt, ng = 0, 0  # number of targets, number grid points |  | ||||||
|     for i, pi0 in enumerate(p):  # layer i predictions, i |     for i, pi0 in enumerate(p):  # layer i predictions, i | ||||||
|         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx |         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx | ||||||
|         tobj = torch.zeros_like(pi0[..., 0])  # target obj |         tobj = torch.zeros_like(pi0[..., 0])  # target obj | ||||||
|         ng += tobj.numel() |  | ||||||
|         nb = len(b) |  | ||||||
| 
 | 
 | ||||||
|         # Compute losses |         # Compute losses | ||||||
|         if nb:  # number of targets |         if len(b):  # number of targets | ||||||
|             nt += nb |  | ||||||
|             pi = pi0[b, a, gj, gi]  # predictions closest to anchors |             pi = pi0[b, a, gj, gi]  # predictions closest to anchors | ||||||
|             tobj[b, a, gj, gi] = 1.0  # obj |             tobj[b, a, gj, gi] = 1.0  # obj | ||||||
|             # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4])  # wh power loss (uncomment) |             # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4])  # wh power loss (uncomment) | ||||||
|  | @ -303,27 +299,21 @@ def compute_loss(p, targets, model, giou_loss=True):  # predictions, targets, mo | ||||||
|             if giou_loss: |             if giou_loss: | ||||||
|                 pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1)  # predicted |                 pbox = torch.cat((torch.sigmoid(pi[..., 0:2]), torch.exp(pi[..., 2:4]) * anchor_vec[i]), 1)  # predicted | ||||||
|                 giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation |                 giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation | ||||||
|                 lxy += (1.0 - giou).sum()  # giou loss |                 lxy += (k * h['giou']) * (1.0 - giou).mean()  # giou loss | ||||||
|             else: |             else: | ||||||
|                 lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss |                 lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss | ||||||
|                 lwh += MSE(pi[..., 2:4], twh[i])  # wh yolo loss |                 lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss | ||||||
| 
 | 
 | ||||||
|             tclsm = torch.zeros_like(pi[..., 5:]) |             # tclsm = torch.zeros_like(pi[..., 5:]) | ||||||
|             tclsm[range(nb), tcls[i]] = 1.0 |             # tclsm[range(len(b)), tcls[i]] = 1.0 | ||||||
|             lcls += BCEcls(pi[..., 5:], tclsm)  # cls loss (BCE) |             # lcls += (k * h['cls']) * BCEcls(pi[..., 5:], tclsm)  # cls loss (BCE) | ||||||
|             # lcls += CE(pi[..., 5:], tcls[i])  # cls loss (CE) |             lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # cls loss (CE) | ||||||
| 
 | 
 | ||||||
|             # Append targets to text file |             # Append targets to text file | ||||||
|             # with open('targets.txt', 'a') as file: |             # with open('targets.txt', 'a') as file: | ||||||
|             #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] |             #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] | ||||||
| 
 | 
 | ||||||
|         lobj += BCEobj(pi0[..., 4], tobj)  # obj loss |         lobj += (k * h['obj']) * BCEobj(pi0[..., 4], tobj)  # obj loss | ||||||
| 
 |  | ||||||
|     lxy *= (k * h['giou']) / nt |  | ||||||
|     lwh *= (k * h['wh']) / nt |  | ||||||
|     lcls *= (k * h['cls']) / (nt * nc) |  | ||||||
|     lobj *= (k * h['obj']) / ng |  | ||||||
| 
 |  | ||||||
|     loss = lxy + lwh + lobj + lcls |     loss = lxy + lwh + lobj + lcls | ||||||
| 
 | 
 | ||||||
|     return loss, torch.cat((lxy, lwh, lobj, lcls, loss)).detach() |     return loss, torch.cat((lxy, lwh, lobj, lcls, loss)).detach() | ||||||
|  | @ -385,7 +375,7 @@ def build_targets(model, targets): | ||||||
|         if c.shape[0]: |         if c.shape[0]: | ||||||
|             assert c.max() <= layer.nc, 'Target classes exceed model classes' |             assert c.max() <= layer.nc, 'Target classes exceed model classes' | ||||||
| 
 | 
 | ||||||
|     return txy, twh, tcls, tbox, indices, anchor_vec, layer.nc |     return txy, twh, tcls, tbox, indices, anchor_vec | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue