updates
This commit is contained in:
parent
ac2b9d580d
commit
c97b669c46
|
@ -174,10 +174,7 @@ class YOLOLayer(nn.Module):
|
||||||
elif arc == 'uCE': # unified CE (1 background + 80 classes)
|
elif arc == 'uCE': # unified CE (1 background + 80 classes)
|
||||||
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
|
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
|
||||||
io[..., 4] = 1
|
io[..., 4] = 1
|
||||||
elif arc == 'uBCE': # unified BCE (1 background + 80 classes)
|
elif arc == 'uBCE': # unified BCE (80 classes)
|
||||||
torch.sigmoid_(io[..., 4:])
|
|
||||||
io[..., 4] = 1 - io[..., 4]
|
|
||||||
elif arc == 'uBCEs': # unified BCE simplified (80 classes)
|
|
||||||
torch.sigmoid_(io[..., 5:])
|
torch.sigmoid_(io[..., 5:])
|
||||||
io[..., 4] = 1
|
io[..., 4] = 1
|
||||||
|
|
||||||
|
|
|
@ -369,12 +369,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
t[b, a, gj, gi, tcls[i]] = 1.0
|
t[b, a, gj, gi, tcls[i]] = 1.0
|
||||||
lcls += BCEcls(pi[..., 5:], t)
|
lcls += BCEcls(pi[..., 5:], t)
|
||||||
|
|
||||||
elif arc == 'uBCEs': # unified BCE simplified (80 classes)
|
|
||||||
t = torch.zeros_like(pi[..., 5:]) # targets
|
|
||||||
if nb:
|
|
||||||
t[b, a, gj, gi, tcls[i]] = 1.0
|
|
||||||
lcls += BCEcls(pi[..., 5:], t)
|
|
||||||
|
|
||||||
lbox *= k * h['giou']
|
lbox *= k * h['giou']
|
||||||
lobj *= k * h['obj']
|
lobj *= k * h['obj']
|
||||||
lcls *= k * h['cls']
|
lcls *= k * h['cls']
|
||||||
|
|
Loading…
Reference in New Issue