@ -258,6 +258,9 @@ def train():
# Compute loss
loss, loss_items = compute_loss(pred, targets, model)
if torch.isnan(loss):
print('WARNING: nan loss detected, skipping batch ', loss_items)
continue
# Scale loss by nominal batch_size of 64
loss *= batch_size / 64