weight_decay fix

This commit is contained in:
Glenn Jocher 2019-08-29 15:44:15 +02:00
parent 7d9ffe6d4e
commit 408baf66e2
1 changed file with 3 additions and 0 deletions

View File

@@ -258,6 +258,9 @@ def train():
# Compute loss # Compute loss
loss, loss_items = compute_loss(pred, targets, model) loss, loss_items = compute_loss(pred, targets, model)
if torch.isnan(loss):
print('WARNING: nan loss detected, skipping batch ', loss_items)
continue
# Scale loss by nominal batch_size of 64 # Scale loss by nominal batch_size of 64
loss *= batch_size / 64 loss *= batch_size / 64