diff --git a/train.py b/train.py
index e85d9343..9a6567ba 100644
--- a/train.py
+++ b/train.py
@@ -258,8 +258,8 @@ def train():
 
             # Compute loss
             loss, loss_items = compute_loss(pred, targets, model)
-            if torch.isnan(loss):
-                print('WARNING: nan loss detected, skipping batch ', loss_items)
+            if not torch.isfinite(loss):
+                print('WARNING: non-finite loss, skipping batch ', loss_items)
                 continue
 
             # Scale loss by nominal batch_size of 64
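
Rationale for the change, as a minimal standalone sketch (not part of train.py): torch.isnan() only flags NaN, while torch.isfinite() is also False for +inf and -inf, so the new check additionally skips batches whose loss has overflowed to infinity.

    import torch

    # Hypothetical values illustrating the two checks.
    nan_loss = torch.tensor(float('nan'))
    inf_loss = torch.tensor(float('inf'))

    print(torch.isnan(nan_loss).item())          # True  -> old check skips this batch
    print(torch.isnan(inf_loss).item())          # False -> old check misses an inf loss
    print(not torch.isfinite(nan_loss).item())   # True  -> new check skips NaN
    print(not torch.isfinite(inf_loss).item())   # True  -> new check also skips inf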