weight_decay fix
parent cfb0b7e426
commit 360a32811c

train.py: 4 changed lines (2 added, 2 removed)
@@ -258,8 +258,8 @@ def train():
 
             # Compute loss
             loss, loss_items = compute_loss(pred, targets, model)
-            if torch.isnan(loss):
-                print('WARNING: nan loss detected, skipping batch ', loss_items)
+            if not torch.isfinite(loss):
+                print('WARNING: non-finite loss, skipping batch ', loss_items)
                 continue
 
             # Scale loss by nominal batch_size of 64
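The change swaps the nan-only check for torch.isfinite, which is also False for inf, so any non-finite loss now causes the batch to be skipped. Below is a minimal sketch of how that guard sits inside one training step; run_batch, imgs, and the optimizer handling are illustrative assumptions, while compute_loss, pred, targets, and model follow the names used in train.py.

import torch

def run_batch(model, optimizer, compute_loss, imgs, targets):
    # Forward pass and loss computation, mirroring the hunk above.
    pred = model(imgs)
    loss, loss_items = compute_loss(pred, targets, model)

    # torch.isfinite is False for nan and for +/-inf, so this skips the
    # batch on any non-finite loss rather than only on nan.
    if not torch.isfinite(loss):
        print('WARNING: non-finite loss, skipping batch ', loss_items)
        return None  # assumed handling: no backward pass or optimizer step

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss_items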