updates
This commit is contained in:
parent
3b1caf9a43
commit
09ff72bc7b
|
@ -436,9 +436,10 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
||||||
lcls *= h['cls']
|
lcls *= h['cls']
|
||||||
if red == 'sum':
|
if red == 'sum':
|
||||||
bs = tobj.shape[0] # batch size
|
bs = tobj.shape[0] # batch size
|
||||||
lbox *= 3 / ng
|
|
||||||
lobj *= 3 / (6300 * bs) * 2 # 3 / np * 2
|
lobj *= 3 / (6300 * bs) * 2 # 3 / np * 2
|
||||||
lcls *= 3 / ng / model.nc
|
if ng:
|
||||||
|
lcls *= 3 / ng / model.nc
|
||||||
|
lbox *= 3 / ng
|
||||||
|
|
||||||
loss = lbox + lobj + lcls
|
loss = lbox + lobj + lcls
|
||||||
return loss, torch.cat((lbox, lobj, lcls, loss)).detach()
|
return loss, torch.cat((lbox, lobj, lcls, loss)).detach()
|
||||||
|
|
Loading…
Reference in New Issue