diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 40eed4ae..03337985 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 
 
 def init_seeds(seed=0):
@@ -58,3 +59,27 @@ def fuse_conv_and_bn(conv, bn):
     fusedconv.bias.copy_(b_conv + b_bn)
 
     return fusedconv
+
+
+class FocalLoss(nn.Module):
+    # Wraps focal loss around an existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
+    # i.e. criterion = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
+    def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
+        super(FocalLoss, self).__init__()
+        self.loss_fcn = loss_fcn  # must support reduction='none', e.g. nn.BCEWithLogitsLoss()
+        self.loss_fcn.reduction = 'none'  # apply the focal term element-wise before reducing
+        self.alpha = alpha
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        loss = self.loss_fcn(input, target)  # element-wise loss
+        pt = torch.exp(-loss)  # probability of the true class
+        loss *= self.alpha * (1 - pt) ** self.gamma  # down-weight easy examples
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:  # 'none'
+            return loss
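
A minimal usage sketch of the new wrapper (not part of the diff): it assumes the repo root is on the import path so FocalLoss can be imported from utils.torch_utils, and the tensor shapes and alpha/gamma values below are illustrative only.

import torch
import torch.nn as nn
from utils.torch_utils import FocalLoss  # assumes repo root on sys.path

# Wrap a standard BCE-with-logits criterion. The wrapper forces the wrapped
# criterion's reduction to 'none' so the focal term is applied element-wise,
# then reduces according to its own 'reduction' argument.
criterion = FocalLoss(nn.BCEWithLogitsLoss(), alpha=0.25, gamma=2.0, reduction='mean')

logits = torch.randn(8, 80, requires_grad=True)  # example raw predictions (assumed shape)
targets = torch.randint(0, 2, (8, 80)).float()   # example binary targets

loss = criterion(logits, targets)  # scalar focal loss
loss.backward()                    # gradients flow through the wrapper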