diff --git a/models.py b/models.py index f836346b..06a18a0b 100755 --- a/models.py +++ b/models.py @@ -130,8 +130,9 @@ class YOLOLayer(nn.Module): # Training if targets is not None: - BCEWithLogitsLoss = nn.BCEWithLogitsLoss() - MSELoss = nn.MSELoss() # version 0.4.0 + BCEWithLogitsLoss1 = nn.BCEWithLogitsLoss(size_average=False) + BCEWithLogitsLoss2 = nn.BCEWithLogitsLoss(size_average=True) + MSELoss = nn.MSELoss(size_average=False) # version 0.4.0 CrossEntropyLoss = nn.CrossEntropyLoss() if requestPrecision: @@ -150,21 +151,21 @@ class YOLOLayer(nn.Module): tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda() # Mask outputs to ignore non-existing objects (but keep confidence predictions) - nM = mask.sum() + nM = mask.sum().float() nGT = sum([len(x) for x in targets]) if nM > 0: lx = 5 * MSELoss(x[mask], tx[mask]) ly = 5 * MSELoss(y[mask], ty[mask]) lw = 5 * MSELoss(w[mask], tw[mask]) lh = 5 * MSELoss(h[mask], th[mask]) - lconf = 1.5 * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float()) + lconf = 1.5 * BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float()) - lcls = 0.5 * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) - # lcls = BCEWithLogitsLoss(pred_cls[mask], tcls.float()) + lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) + # lcls = BCEWithLogitsLoss1(pred_cls[mask], tcls.float()) else: lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) - lconf += BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) + lconf += nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float()) loss = lx + ly + lw + lh + lconf + lcls i = torch.sigmoid(pred_conf[~mask]) > 0.99