diff --git a/utils/utils.py b/utils/utils.py index 93653786..f7905c8b 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -301,8 +301,7 @@ class FocalLoss(nn.Module): def forward(self, input, target): loss = self.loss_fcn(input, target) - pt = torch.exp(-loss) - loss *= self.alpha * (1 - pt) ** self.gamma + loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability if self.reduction == 'mean': return loss.mean()