updates
This commit is contained in:
parent
fd2991386f
commit
0aece25ef6
|
@ -59,26 +59,3 @@ def fuse_conv_and_bn(conv, bn):
|
||||||
fusedconv.bias.copy_(b_conv + b_bn)
|
fusedconv.bias.copy_(b_conv + b_bn)
|
||||||
|
|
||||||
return fusedconv
|
return fusedconv
|
||||||
|
|
||||||
|
|
||||||
class FocalLoss(nn.Module):
|
|
||||||
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
|
|
||||||
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
|
|
||||||
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
|
|
||||||
super(FocalLoss, self).__init__()
|
|
||||||
self.loss_fcn = loss_fcn
|
|
||||||
self.alpha = alpha
|
|
||||||
self.gamma = gamma
|
|
||||||
self.reduction = reduction
|
|
||||||
|
|
||||||
def forward(self, input, target):
|
|
||||||
loss = self.loss_fcn(input, target, reduction='none')
|
|
||||||
pt = torch.exp(-loss)
|
|
||||||
loss *= self.alpha * (1 - pt) ** self.gamma
|
|
||||||
|
|
||||||
if self.reduction == 'mean':
|
|
||||||
return loss.mean()
|
|
||||||
elif self.reduction == 'sum':
|
|
||||||
return loss.sum()
|
|
||||||
else: # 'none'
|
|
||||||
return loss
|
|
||||||
|
|
|
@ -288,6 +288,29 @@ def wh_iou(box1, box2):
|
||||||
return inter_area / union_area # iou
|
return inter_area / union_area # iou
|
||||||
|
|
||||||
|
|
||||||
|
class FocalLoss(nn.Module):
|
||||||
|
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
|
||||||
|
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
|
||||||
|
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
|
||||||
|
super(FocalLoss, self).__init__()
|
||||||
|
self.loss_fcn = loss_fcn
|
||||||
|
self.alpha = alpha
|
||||||
|
self.gamma = gamma
|
||||||
|
self.reduction = reduction
|
||||||
|
|
||||||
|
def forward(self, input, target):
|
||||||
|
loss = self.loss_fcn(input, target, reduction='none')
|
||||||
|
pt = torch.exp(-loss)
|
||||||
|
loss *= self.alpha * (1 - pt) ** self.gamma
|
||||||
|
|
||||||
|
if self.reduction == 'mean':
|
||||||
|
return loss.mean()
|
||||||
|
elif self.reduction == 'sum':
|
||||||
|
return loss.sum()
|
||||||
|
else: # 'none'
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
|
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
|
||||||
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
|
||||||
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])
|
||||||
|
|
Loading…
Reference in New Issue