updates
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c8c0660e6a
commit
9251dfd6a5
|
@ -301,8 +301,7 @@ class FocalLoss(nn.Module):
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
loss = self.loss_fcn(input, target)
|
loss = self.loss_fcn(input, target)
|
||||||
pt = torch.exp(-loss)
|
loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
|
||||||
loss *= self.alpha * (1 - pt) ** self.gamma
|
|
||||||
|
|
||||||
if self.reduction == 'mean':
|
if self.reduction == 'mean':
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
Loading…
Reference in New Issue