EMA class updates

This commit is contained in:
Glenn Jocher 2020-03-16 14:18:56 -07:00
parent c4047000fe
commit c09fcfc4fe
1 changed files with 2 additions and 1 deletions

View File

@ -349,7 +349,8 @@ class FocalLoss(nn.Module):
def forward(self, pred, true):
loss = self.loss_fcn(pred, true)
# loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = torch.sigmoid(pred) # prob from logits