diff --git a/train.py b/train.py index c005480a..938e19e5 100644 --- a/train.py +++ b/train.py @@ -259,7 +259,7 @@ def train(): # Compute loss loss, loss_items = compute_loss(pred, targets, model) if torch.isnan(loss): - print('WARNING: nan loss detected, ending training') + print('WARNING: nan loss detected, ending training', loss_items) return results # Scale loss by nominal batch_size of 64