This commit is contained in:
Glenn Jocher 2018-08-26 23:52:06 +02:00
parent 8944db80b9
commit 2769d79d05
1 changed files with 8 additions and 7 deletions

View File

@ -130,8 +130,9 @@ class YOLOLayer(nn.Module):
# Training # Training
if targets is not None: if targets is not None:
BCEWithLogitsLoss = nn.BCEWithLogitsLoss() BCEWithLogitsLoss1 = nn.BCEWithLogitsLoss(size_average=False)
MSELoss = nn.MSELoss() # version 0.4.0 BCEWithLogitsLoss2 = nn.BCEWithLogitsLoss(size_average=True)
MSELoss = nn.MSELoss(size_average=False) # version 0.4.0
CrossEntropyLoss = nn.CrossEntropyLoss() CrossEntropyLoss = nn.CrossEntropyLoss()
if requestPrecision: if requestPrecision:
@ -150,21 +151,21 @@ class YOLOLayer(nn.Module):
tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda() tx, ty, tw, th, mask, tcls = tx.cuda(), ty.cuda(), tw.cuda(), th.cuda(), mask.cuda(), tcls.cuda()
# Mask outputs to ignore non-existing objects (but keep confidence predictions) # Mask outputs to ignore non-existing objects (but keep confidence predictions)
nM = mask.sum() nM = mask.sum().float()
nGT = sum([len(x) for x in targets]) nGT = sum([len(x) for x in targets])
if nM > 0: if nM > 0:
lx = 5 * MSELoss(x[mask], tx[mask]) lx = 5 * MSELoss(x[mask], tx[mask])
ly = 5 * MSELoss(y[mask], ty[mask]) ly = 5 * MSELoss(y[mask], ty[mask])
lw = 5 * MSELoss(w[mask], tw[mask]) lw = 5 * MSELoss(w[mask], tw[mask])
lh = 5 * MSELoss(h[mask], th[mask]) lh = 5 * MSELoss(h[mask], th[mask])
lconf = 1.5 * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float()) lconf = 1.5 * BCEWithLogitsLoss1(pred_conf[mask], mask[mask].float())
lcls = 0.5 * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1)) lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
# lcls = BCEWithLogitsLoss(pred_cls[mask], tcls.float()) # lcls = BCEWithLogitsLoss1(pred_cls[mask], tcls.float())
else: else:
lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0]) lx, ly, lw, lh, lcls, lconf = FT([0]), FT([0]), FT([0]), FT([0]), FT([0]), FT([0])
lconf += BCEWithLogitsLoss(pred_conf[~mask], mask[~mask].float()) lconf += nM * BCEWithLogitsLoss2(pred_conf[~mask], mask[~mask].float())
loss = lx + ly + lw + lh + lconf + lcls loss = lx + ly + lw + lh + lconf + lcls
i = torch.sigmoid(pred_conf[~mask]) > 0.99 i = torch.sigmoid(pred_conf[~mask]) > 0.99