diff --git a/train.py b/train.py index 127adad2..5df73657 100644 --- a/train.py +++ b/train.py @@ -258,6 +258,9 @@ def train(): # Compute loss loss, loss_items = compute_loss(pred, targets, model) + if torch.isnan(loss): + print('WARNING: nan loss detected, skipping batch ', loss_items) + continue # Scale loss by nominal batch_size of 64 loss *= batch_size / 64