diff --git a/utils/utils.py b/utils/utils.py index b724ccbd..5bc84fff 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -436,9 +436,10 @@ def compute_loss(p, targets, model): # predictions, targets, model lcls *= h['cls'] if red == 'sum': bs = tobj.shape[0] # batch size - lbox *= 3 / ng lobj *= 3 / (6300 * bs) * 2 # 3 / np * 2 - lcls *= 3 / ng / model.nc + if ng: + lcls *= 3 / ng / model.nc + lbox *= 3 / ng loss = lbox + lobj + lcls return loss, torch.cat((lbox, lobj, lcls, loss)).detach()