This commit is contained in:
Glenn Jocher 2020-01-31 09:00:45 -08:00
parent 0c7af1a4d2
commit 189c7044fb
1 changed files with 3 additions and 3 deletions

View File

@ -343,13 +343,13 @@ def wh_iou(wh1, wh2):
class FocalLoss(nn.Module): class FocalLoss(nn.Module):
# Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf # Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf
# i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5) # i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5)
def __init__(self, loss_fcn, gamma=0.5, alpha=1, reduction='mean'): def __init__(self, loss_fcn, gamma=0.5, alpha=1):
super(FocalLoss, self).__init__() super(FocalLoss, self).__init__()
loss_fcn.reduction = 'none' # required to apply FL to each element
self.loss_fcn = loss_fcn self.loss_fcn = loss_fcn
self.gamma = gamma self.gamma = gamma
self.alpha = alpha self.alpha = alpha
self.reduction = reduction self.reduction = loss_fcn.reduction
self.loss_fcn.reduction = 'none' # required to apply FL to each element
def forward(self, input, target): def forward(self, input, target):
loss = self.loss_fcn(input, target) loss = self.loss_fcn(input, target)