updates
This commit is contained in:
parent
3527e61526
commit
fd2991386f
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def init_seeds(seed=0):
|
||||
|
@ -58,3 +59,26 @@ def fuse_conv_and_bn(conv, bn):
|
|||
fusedconv.bias.copy_(b_conv + b_bn)
|
||||
|
||||
return fusedconv
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
|
||||
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
|
||||
def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.loss_fcn = loss_fcn
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, input, target):
|
||||
loss = self.loss_fcn(input, target, reduction='none')
|
||||
pt = torch.exp(-loss)
|
||||
loss *= self.alpha * (1 - pt) ** self.gamma
|
||||
|
||||
if self.reduction == 'mean':
|
||||
return loss.mean()
|
||||
elif self.reduction == 'sum':
|
||||
return loss.sum()
|
||||
else: # 'none'
|
||||
return loss
|
||||
|
|
Loading…
Reference in New Issue