FocalLoss() updated to match TF
This commit is contained in:
parent
07d2f0ad03
commit
c4047000fe
|
@ -338,19 +338,25 @@ def wh_iou(wh1, wh2):
|
||||||
|
|
||||||
|
|
||||||
class FocalLoss(nn.Module):
|
class FocalLoss(nn.Module):
|
||||||
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
|
# Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
|
||||||
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
|
|
||||||
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
|
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
|
||||||
super(FocalLoss, self).__init__()
|
super(FocalLoss, self).__init__()
|
||||||
self.loss_fcn = loss_fcn
|
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
self.alpha = alpha
|
self.alpha = alpha
|
||||||
self.reduction = loss_fcn.reduction
|
self.reduction = loss_fcn.reduction
|
||||||
self.loss_fcn.reduction = 'none' # required to apply FL to each element
|
self.loss_fcn.reduction = 'none' # required to apply FL to each element
|
||||||
|
|
||||||
def forward(self, input, target):
|
def forward(self, pred, true):
|
||||||
loss = self.loss_fcn(input, target)
|
loss = self.loss_fcn(pred, true)
|
||||||
loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
|
# loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
|
||||||
|
|
||||||
|
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
|
||||||
|
pred_prob = torch.sigmoid(pred) # prob from logits
|
||||||
|
p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
|
||||||
|
alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
|
||||||
|
modulating_factor = (1.0 - p_t) ** self.gamma
|
||||||
|
loss = alpha_factor * modulating_factor * loss
|
||||||
|
|
||||||
if self.reduction == 'mean':
|
if self.reduction == 'mean':
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
Loading…
Reference in New Issue