From 408baf66e2e56b08f08e7034ca6e383ff396d29c Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 29 Aug 2019 15:44:15 +0200
Subject: [PATCH] weight_decay fix

---
 train.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/train.py b/train.py
index 127adad2..5df73657 100644
--- a/train.py
+++ b/train.py
@@ -258,6 +258,9 @@ def train():
 
             # Compute loss
             loss, loss_items = compute_loss(pred, targets, model)
+            if torch.isnan(loss):
+                print('WARNING: nan loss detected, skipping batch ', loss_items)
+                continue
 
             # Scale loss by nominal batch_size of 64
             loss *= batch_size / 64