weight_decay fix

This commit is contained in:
Glenn Jocher 2019-08-26 16:24:19 +02:00
parent ff82e4d488
commit 798a7396f1
1 changed file with 2 additions and 3 deletions

View File

@@ -261,9 +261,8 @@ def train():
print('WARNING: nan loss detected, ending training')
return results
# Divide by accumulation count
if accumulate > 1:
loss /= accumulate
# Scale loss by nominal batch_size of 64
loss *= batch_size / 64
# Compute gradient
if mixed_precision: