updates
This commit is contained in:
		
							parent
							
								
									6ab753a9e7
								
							
						
					
					
						commit
						981b452b1d
					
				
							
								
								
									
										3
									
								
								train.py
								
								
								
								
							
							
						
						
									
										3
									
								
								train.py
								
								
								
								
							| 
						 | 
					@ -211,6 +211,7 @@ def train():
 | 
				
			||||||
    print('Starting training for %g epochs...' % epochs)
 | 
					    print('Starting training for %g epochs...' % epochs)
 | 
				
			||||||
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
 | 
					    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
 | 
				
			||||||
        model.train()
 | 
					        model.train()
 | 
				
			||||||
 | 
					        model.hyps['gr'] = 1 - (1 + math.cos(min(epoch * 2, epochs) * math.pi / epochs)) / 2  # GIoU <-> 1.0 ratio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Prebias
 | 
					        # Prebias
 | 
				
			||||||
        if prebias:
 | 
					        if prebias:
 | 
				
			||||||
| 
						 | 
					@ -271,7 +272,7 @@ def train():
 | 
				
			||||||
            pred = model(imgs)
 | 
					            pred = model(imgs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Compute loss
 | 
					            # Compute loss
 | 
				
			||||||
            loss, loss_items = compute_loss(pred, targets, model, not prebias)
 | 
					            loss, loss_items = compute_loss(pred, targets, model)
 | 
				
			||||||
            if not torch.isfinite(loss):
 | 
					            if not torch.isfinite(loss):
 | 
				
			||||||
                print('WARNING: non-finite loss, ending training ', loss_items)
 | 
					                print('WARNING: non-finite loss, ending training ', loss_items)
 | 
				
			||||||
                return results
 | 
					                return results
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -363,7 +363,7 @@ class FocalLoss(nn.Module):
 | 
				
			||||||
            return loss
 | 
					            return loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def compute_loss(p, targets, model, giou_flag=True):  # predictions, targets, model
 | 
					def compute_loss(p, targets, model):  # predictions, targets, model
 | 
				
			||||||
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
 | 
					    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
 | 
				
			||||||
    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
 | 
					    lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
 | 
				
			||||||
    tcls, tbox, indices, anchor_vec = build_targets(model, targets)
 | 
					    tcls, tbox, indices, anchor_vec = build_targets(model, targets)
 | 
				
			||||||
| 
						 | 
					@ -401,7 +401,7 @@ def compute_loss(p, targets, model, giou_flag=True):  # predictions, targets, mo
 | 
				
			||||||
            pbox = torch.cat((pxy, pwh), 1)  # predicted box
 | 
					            pbox = torch.cat((pxy, pwh), 1)  # predicted box
 | 
				
			||||||
            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation
 | 
					            giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True)  # giou computation
 | 
				
			||||||
            lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss
 | 
					            lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean()  # giou loss
 | 
				
			||||||
            tobj[b, a, gj, gi] = giou.detach().clamp(0).type(tobj.dtype) if giou_flag else 1.0
 | 
					            tobj[b, a, gj, gi] = (1.0 - h['gr']) + h['gr'] * giou.detach().clamp(0).type(tobj.dtype)  # giou ratio
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if 'default' in arc and model.nc > 1:  # cls loss (only if multiple classes)
 | 
					            if 'default' in arc and model.nc > 1:  # cls loss (only if multiple classes)
 | 
				
			||||||
                t = torch.zeros_like(ps[:, 5:])  # targets
 | 
					                t = torch.zeros_like(ps[:, 5:])  # targets
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue