weight_decay fix

parent 7d9ffe6d4e
commit 408baf66e2

train.py | 3 +++
@@ -258,6 +258,9 @@ def train():
         # Compute loss
         loss, loss_items = compute_loss(pred, targets, model)
+        if torch.isnan(loss):
+            print('WARNING: nan loss detected, skipping batch ', loss_items)
+            continue
 
         # Scale loss by nominal batch_size of 64
         loss *= batch_size / 64
 
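For context, the added NaN guard and the existing loss scaling sit inside the training step roughly as sketched below. This is a minimal, self-contained illustration, not the repository's actual train.py: the tiny linear model, train_step, and compute_loss_stub are hypothetical stand-ins (the real compute_loss is defined elsewhere in the repo), while the NaN check, warning message, and the scaling to a nominal batch size of 64 come from the diff above.

import torch
import torch.nn as nn


def compute_loss_stub(pred, targets, model):
    # Stand-in for the repository's compute_loss(): returns a scalar loss
    # and a tensor of loss components for logging.
    loss = nn.functional.mse_loss(pred, targets)
    return loss, loss.detach().unsqueeze(0)


def train_step(model, optimizer, imgs, targets, batch_size, compute_loss=compute_loss_stub):
    pred = model(imgs)

    # Compute loss
    loss, loss_items = compute_loss(pred, targets, model)

    # Skip this batch if the loss is NaN so a single bad batch
    # does not propagate NaNs into the model weights.
    if torch.isnan(loss):
        print('WARNING: nan loss detected, skipping batch ', loss_items)
        return None

    # Scale loss by nominal batch_size of 64 so gradient magnitude
    # is roughly independent of the actual batch size used.
    loss *= batch_size / 64

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss_items


if __name__ == '__main__':
    model = nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    imgs = torch.randn(16, 8)
    targets = torch.randn(16, 1)
    train_step(model, optimizer, imgs, targets, batch_size=16)

Skipping the optimizer step entirely (rather than zeroing the loss) keeps the optimizer's momentum buffers untouched by the corrupted batch.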