EMA class updates
This commit is contained in:
parent
c4047000fe
commit
c09fcfc4fe
|
@ -349,7 +349,8 @@ class FocalLoss(nn.Module):
|
|||
|
||||
def forward(self, pred, true):
|
||||
loss = self.loss_fcn(pred, true)
|
||||
# loss *= self.alpha * (1.000001 - torch.exp(-loss)) ** self.gamma # non-zero power for gradient stability
|
||||
# p_t = torch.exp(-loss)
|
||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
||||
|
||||
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
|
||||
pred_prob = torch.sigmoid(pred) # prob from logits
|
||||
|
|
Loading…
Reference in New Issue