Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2019-09-01 16:28:25 +02:00
parent c8c0660e6a
commit 9251dfd6a5
1 changed files with 1 additions and 2 deletions

View File

@ -301,8 +301,7 @@ class FocalLoss(nn.Module):
def forward(self, input, target):
loss = self.loss_fcn(input, target)
pt = torch.exp(-loss)
loss *= self.alpha * (1 - pt) ** self.gamma
loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
if self.reduction == 'mean':
return loss.mean()