updates
This commit is contained in:
		
							parent
							
								
									3bac3c63b1
								
							
						
					
					
						commit
						43956d6305
					
				
							
								
								
									
										4
									
								
								train.py
								
								
								
								
							
							
						
						
									
										4
									
								
								train.py
								
								
								
								
							|  | @ -221,7 +221,7 @@ def train(): | |||
| 
 | ||||
|         # Prebias | ||||
|         if prebias: | ||||
|             if epoch < 3:  # prebias | ||||
|             if epoch < 1:  # prebias | ||||
|                 ps = 0.1, 0.9  # prebias settings (lr=0.1, momentum=0.9) | ||||
|             else:  # normal training | ||||
|                 ps = hyp['lr0'], hyp['momentum']  # normal training settings | ||||
|  | @ -278,7 +278,7 @@ def train(): | |||
|             pred = model(imgs) | ||||
| 
 | ||||
|             # Compute loss | ||||
|             loss, loss_items = compute_loss(pred, targets, model) | ||||
|             loss, loss_items = compute_loss(pred, targets, model, not prebias) | ||||
|             if not torch.isfinite(loss): | ||||
|                 print('WARNING: non-finite loss, ending training ', loss_items) | ||||
|                 return results | ||||
|  |  | |||
|  | @ -361,7 +361,7 @@ class FocalLoss(nn.Module): | |||
|             return loss | ||||
| 
 | ||||
| 
 | ||||
| def compute_loss(p, targets, model):  # predictions, targets, model | ||||
| def compute_loss(p, targets, model, giou_flag=True):  # predictions, targets, model | ||||
|     ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor | ||||
|     lcls, lbox, lobj = ft([0]), ft([0]), ft([0]) | ||||
|     tcls, tbox, indices, anchor_vec = build_targets(model, targets) | ||||
|  | @ -399,7 +399,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model | |||
|             pbox = torch.cat((pxy, pwh), 1)  # predicted box | ||||
|             giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation | ||||
|             lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss | ||||
|             tobj[b, a, gj, gi] = giou.detach().type(tobj.dtype) | ||||
|             tobj[b, a, gj, gi] = giou.detach().type(tobj.dtype) if giou_flag else 1.0 | ||||
| 
 | ||||
|             if 'default' in arc and model.nc > 1:  # cls loss (only if multiple classes) | ||||
|                 t = torch.zeros_like(ps[:, 5:])  # targets | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue