From b3adc896f993fee3d3886b4918f0f144ada7f311 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 16 Mar 2020 21:40:57 -0700 Subject: [PATCH] focal and obj loss speed/stability --- utils/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 4bcb8f2f..37dacaf6 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -357,7 +357,7 @@ class FocalLoss(nn.Module): p_t = true * pred_prob + (1 - true) * (1 - pred_prob) alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) modulating_factor = (1.0 - p_t) ** self.gamma - loss = alpha_factor * modulating_factor * loss + loss *= alpha_factor * modulating_factor if self.reduction == 'mean': return loss.mean() @@ -411,7 +411,7 @@ def compute_loss(p, targets, model): # predictions, targets, model pbox = torch.cat((pxy, pwh), 1) # predicted box giou = bbox_iou(pbox.t(), tbox[i], x1y1x2y2=False, GIoU=True) # giou computation lbox += (1.0 - giou).sum() if red == 'sum' else (1.0 - giou).mean() # giou loss - tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().type(tobj.dtype) # giou ratio + tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * giou.detach().clamp(0).type(tobj.dtype) # giou ratio if model.nc > 1: # cls loss (only if multiple classes) t = torch.full_like(ps[:, 5:], cn) # targets