This commit is contained in:
Glenn Jocher 2019-08-24 23:58:08 +02:00
parent 3ee457cd3d
commit a85f7d967c
2 changed files with 7 additions and 5 deletions

View File

@ -249,6 +249,10 @@ def train():
print('WARNING: nan loss detected, ending training')
return results
# Divide by accumulation count
if accumulate > 1:
loss /= accumulate
# Compute gradient
if mixed_precision:
with amp.scale_loss(loss, optimizer) as scaled_loss:

View File

@ -326,8 +326,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
FCE = FocalLoss(nn.CrossEntropyLoss()) # weight=model.class_weights
# Compute losses
bs = p[0].shape[0] # batch size
k = bs / 64 # loss gain
for i, pi in enumerate(p): # layer index, layer predictions
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
tobj = torch.zeros_like(pi[..., 0]) # target obj
@ -370,9 +368,9 @@ def compute_loss(p, targets, model): # predictions, targets, model
t[b, a, gj, gi, tcls[i]] = 1.0
lobj += FBCE(pi[..., 5:], t)
lbox *= k * h['giou']
lobj *= k * h['obj']
lcls *= k * h['cls']
lbox *= h['giou']
lobj *= h['obj']
lcls *= h['cls']
loss = lbox + lobj + lcls
return loss, torch.cat((lbox, lobj, lcls, loss)).detach()