updates
This commit is contained in:
parent
ac2b9d580d
commit
c97b669c46
|
@ -174,10 +174,7 @@ class YOLOLayer(nn.Module):
|
|||
elif arc == 'uCE': # unified CE (1 background + 80 classes)
|
||||
io[..., 4:] = F.softmax(io[..., 4:], dim=4)
|
||||
io[..., 4] = 1
|
||||
elif arc == 'uBCE': # unified BCE (1 background + 80 classes)
|
||||
torch.sigmoid_(io[..., 4:])
|
||||
io[..., 4] = 1 - io[..., 4]
|
||||
elif arc == 'uBCEs': # unified BCE simplified (80 classes)
|
||||
elif arc == 'uBCE': # unified BCE (80 classes)
|
||||
torch.sigmoid_(io[..., 5:])
|
||||
io[..., 4] = 1
|
||||
|
||||
|
|
|
@ -369,12 +369,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
t[b, a, gj, gi, tcls[i]] = 1.0
|
||||
lcls += BCEcls(pi[..., 5:], t)
|
||||
|
||||
elif arc == 'uBCEs': # unified BCE simplified (80 classes)
|
||||
t = torch.zeros_like(pi[..., 5:]) # targets
|
||||
if nb:
|
||||
t[b, a, gj, gi, tcls[i]] = 1.0
|
||||
lcls += BCEcls(pi[..., 5:], t)
|
||||
|
||||
lbox *= k * h['giou']
|
||||
lobj *= k * h['obj']
|
||||
lcls *= k * h['cls']
|
||||
|
|
Loading…
Reference in New Issue