This commit is contained in:
Glenn Jocher 2019-08-20 14:38:56 +02:00
parent ac2b9d580d
commit c97b669c46
2 changed files with 1 additions and 10 deletions

View File

@ -174,10 +174,7 @@ class YOLOLayer(nn.Module):
elif arc == 'uCE': # unified CE (1 background + 80 classes) elif arc == 'uCE': # unified CE (1 background + 80 classes)
io[..., 4:] = F.softmax(io[..., 4:], dim=4) io[..., 4:] = F.softmax(io[..., 4:], dim=4)
io[..., 4] = 1 io[..., 4] = 1
elif arc == 'uBCE': # unified BCE (1 background + 80 classes) elif arc == 'uBCE': # unified BCE (80 classes)
torch.sigmoid_(io[..., 4:])
io[..., 4] = 1 - io[..., 4]
elif arc == 'uBCEs': # unified BCE simplified (80 classes)
torch.sigmoid_(io[..., 5:]) torch.sigmoid_(io[..., 5:])
io[..., 4] = 1 io[..., 4] = 1

View File

@ -369,12 +369,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
t[b, a, gj, gi, tcls[i]] = 1.0 t[b, a, gj, gi, tcls[i]] = 1.0
lcls += BCEcls(pi[..., 5:], t) lcls += BCEcls(pi[..., 5:], t)
elif arc == 'uBCEs': # unified BCE simplified (80 classes)
t = torch.zeros_like(pi[..., 5:]) # targets
if nb:
t[b, a, gj, gi, tcls[i]] = 1.0
lcls += BCEcls(pi[..., 5:], t)
lbox *= k * h['giou'] lbox *= k * h['giou']
lobj *= k * h['obj'] lobj *= k * h['obj']
lcls *= k * h['cls'] lcls *= k * h['cls']