diff --git a/utils/utils.py b/utils/utils.py
index ec997b2a..929f6ba4 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -338,19 +338,25 @@ def wh_iou(wh1, wh2):
 
 class FocalLoss(nn.Module):
-    # Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
-    # i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
+    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
         super(FocalLoss, self).__init__()
-        self.loss_fcn = loss_fcn
+        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
         self.gamma = gamma
         self.alpha = alpha
         self.reduction = loss_fcn.reduction
         self.loss_fcn.reduction = 'none'  # required to apply FL to each element
 
-    def forward(self, input, target):
-        loss = self.loss_fcn(input, target)
-        loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma  # non-zero power for gradient stability
+    def forward(self, pred, true):
+        loss = self.loss_fcn(pred, true)
+        # loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma  # non-zero power for gradient stability
+
+        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+        pred_prob = torch.sigmoid(pred)  # prob from logits
+        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
+        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+        modulating_factor = (1.0 - p_t) ** self.gamma
+        loss = alpha_factor * modulating_factor * loss
 
         if self.reduction == 'mean':
             return loss.mean()
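
A minimal sanity-check sketch of the patched class, assuming only PyTorch. The FocalLoss body below is a standalone copy of the code as it reads after this diff; the reduction tail for 'sum' and 'none' is reconstructed from the stored self.reduction attribute and is not shown in the hunk above, and the tensor shapes in the usage lines are illustrative, not taken from the original repo.

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)                        # elementwise BCE-with-logits
        pred_prob = torch.sigmoid(pred)                         # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)   # prob assigned to the true class
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma           # down-weights easy examples
        loss = alpha_factor * modulating_factor * loss
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':                           # reconstructed, not in the hunk above
            return loss.sum()
        return loss                                             # 'none'

pred = torch.randn(8, 80)                    # raw logits (illustrative shape)
true = torch.randint(0, 2, (8, 80)).float()  # binary targets

# Typical usage: wrap the existing criterion, then call it on raw logits.
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
print(criteria(pred, true))  # scalar; 'mean' reduction inherited from BCEWithLogitsLoss

# With gamma=0 the modulating factor is 1, and alpha=0.5 scales every element
# by 0.5, so the result must equal 0.5 * plain BCE-with-logits.
fl0 = FocalLoss(nn.BCEWithLogitsLoss(), gamma=0.0, alpha=0.5)
bce = nn.BCEWithLogitsLoss()(pred, true)
assert torch.allclose(fl0(pred, true), 0.5 * bce)

Note the design choice the diff makes: rather than recovering p_t from the loss value via torch.exp(-loss) (which needed the 1.000001 epsilon for gradient stability), the new code recomputes p_t directly from the logits, matching the TensorFlow Addons formulation linked in the comment. A fresh nn.BCEWithLogitsLoss() is passed to each wrapper above because the constructor mutates its reduction to 'none'.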