From 360a32811c2c2eee3f182c63f90abc328a809a7e Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 31 Aug 2019 17:55:19 +0200 Subject: [PATCH] weight_decay fix --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index e85d9343..9a6567ba 100644 --- a/train.py +++ b/train.py @@ -258,8 +258,8 @@ def train(): # Compute loss loss, loss_items = compute_loss(pred, targets, model) - if torch.isnan(loss): - print('WARNING: nan loss detected, skipping batch ', loss_items) + if not torch.isfinite(loss): + print('WARNING: non-finite loss, skipping batch ', loss_items) continue # Scale loss by nominal batch_size of 64