updates
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c8c0660e6a
commit
9251dfd6a5
|
@ -301,8 +301,7 @@ class FocalLoss(nn.Module):
|
|||
|
||||
def forward(self, input, target):
|
||||
loss = self.loss_fcn(input, target)
|
||||
pt = torch.exp(-loss)
|
||||
loss *= self.alpha * (1 - pt) ** self.gamma
|
||||
loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
|
||||
|
||||
if self.reduction == 'mean':
|
||||
return loss.mean()
|
||||
|
|
Loading…
Reference in New Issue