diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 03337985..60974fe2 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -59,26 +59,3 @@ def fuse_conv_and_bn(conv, bn): fusedconv.bias.copy_(b_conv + b_bn) return fusedconv - - -class FocalLoss(nn.Module): - # Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf - # i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5) - def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'): - super(FocalLoss, self).__init__() - self.loss_fcn = loss_fcn - self.alpha = alpha - self.gamma = gamma - self.reduction = reduction - - def forward(self, input, target): - loss = self.loss_fcn(input, target, reduction='none') - pt = torch.exp(-loss) - loss *= self.alpha * (1 - pt) ** self.gamma - - if self.reduction == 'mean': - return loss.mean() - elif self.reduction == 'sum': - return loss.sum() - else: # 'none' - return loss diff --git a/utils/utils.py b/utils/utils.py index e3621a50..0d2e2b39 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -288,6 +288,29 @@ def wh_iou(box1, box2): return inter_area / union_area # iou +class FocalLoss(nn.Module): + # Wraps focal loss around existing loss_fcn() https://arxiv.org/pdf/1708.02002.pdf + # i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=2.5) + def __init__(self, loss_fcn, alpha=1, gamma=2, reduction='mean'): + super(FocalLoss, self).__init__() + self.loss_fcn = loss_fcn + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, input, target): + loss = self.loss_fcn(input, target, reduction='none') + pt = torch.exp(-loss) + loss *= self.alpha * (1 - pt) ** self.gamma + + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: # 'none' + return loss + + def compute_loss(p, targets, model, giou_loss=True): # predictions, targets, model ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor lcls, lbox, lobj = ft([0]), ft([0]), ft([0])