updates
This commit is contained in:
parent
3ee457cd3d
commit
a85f7d967c
4
train.py
4
train.py
|
@ -249,6 +249,10 @@ def train():
|
|||
print('WARNING: nan loss detected, ending training')
|
||||
return results
|
||||
|
||||
# Divide by accumulation count
|
||||
if accumulate > 1:
|
||||
loss /= accumulate
|
||||
|
||||
# Compute gradient
|
||||
if mixed_precision:
|
||||
with amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
|
|
|
@ -326,8 +326,6 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
FCE = FocalLoss(nn.CrossEntropyLoss()) # weight=model.class_weights
|
||||
|
||||
# Compute losses
|
||||
bs = p[0].shape[0] # batch size
|
||||
k = bs / 64 # loss gain
|
||||
for i, pi in enumerate(p): # layer index, layer predictions
|
||||
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
||||
tobj = torch.zeros_like(pi[..., 0]) # target obj
|
||||
|
@ -370,9 +368,9 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|||
t[b, a, gj, gi, tcls[i]] = 1.0
|
||||
lobj += FBCE(pi[..., 5:], t)
|
||||
|
||||
lbox *= k * h['giou']
|
||||
lobj *= k * h['obj']
|
||||
lcls *= k * h['cls']
|
||||
lbox *= h['giou']
|
||||
lobj *= h['obj']
|
||||
lcls *= h['cls']
|
||||
loss = lbox + lobj + lcls
|
||||
return loss, torch.cat((lbox, lobj, lcls, loss)).detach()
|
||||
|
||||
|
|
Loading…
Reference in New Issue