This commit is contained in:
Glenn Jocher 2019-08-18 01:58:35 +02:00
parent fd2991386f
commit 0aece25ef6
2 changed files with 23 additions and 23 deletions

View File

@ -59,26 +59,3 @@ def fuse_conv_and_bn(conv, bn):
fusedconv.bias.copy_(b_conv + b_bn)
return fusedconv
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.loss_fcn = loss_fcn
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, input, target):
loss = self.loss_fcn(input, target, reduction='none')
pt = torch.exp(-loss)
loss *= self.alpha * (1 - pt) ** self.gamma
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss

View File

@ -288,6 +288,29 @@ def wh_iou(box1, box2):
return inter_area / union_area # iou
class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.loss_fcn = loss_fcn
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, input, target):
loss = self.loss_fcn(input, target, reduction='none')
pt = torch.exp(-loss)
loss *= self.alpha * (1 - pt) ** self.gamma
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss
def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model
ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
lcls, lbox, lobj = ft([0]), ft([0]), ft([0])