weight_decay fix

This commit is contained in:
Glenn Jocher 2019-08-31 17:55:19 +02:00
parent cfb0b7e426
commit 360a32811c
1 changed files with 2 additions and 2 deletions

View File

@ -258,8 +258,8 @@ def train():
# Compute loss # Compute loss
loss, loss_items = compute_loss(pred, targets, model) loss, loss_items = compute_loss(pred, targets, model)
if torch.isnan(loss): if not torch.isfinite(loss):
print('WARNING: nan loss detected, skipping batch ', loss_items) print('WARNING: non-finite loss, skipping batch ', loss_items)
continue continue
# Scale loss by nominal batch_size of 64 # Scale loss by nominal batch_size of 64