diff --git a/utils/utils.py b/utils/utils.py index 6402f84b..1fd88e9c 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -343,13 +343,13 @@ def wh_iou(wh1, wh2): class FocalLoss(nn.Module): # Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf # i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5) - def __init__(self, loss_fcn, gamma=0.5, alpha=1, reduction='mean'): + def __init__(self, loss_fcn, gamma=0.5, alpha=1): super(FocalLoss, self).__init__() - loss_fcn.reduction = 'none' # required to apply FL to each element self.loss_fcn = loss_fcn self.gamma = gamma self.alpha = alpha - self.reduction = reduction + self.reduction = loss_fcn.reduction + self.loss_fcn.reduction = 'none' # required to apply FL to each element def forward(self, input, target): loss = self.loss_fcn(input, target)