weight_decay fix

This commit is contained in:
Glenn Jocher 2019-08-26 16:24:19 +02:00
parent ff82e4d488
commit 798a7396f1
1 changed files with 2 additions and 3 deletions

View File

@ -261,9 +261,8 @@ def train():
print('WARNING: nan loss detected, ending training') print('WARNING: nan loss detected, ending training')
return results return results
# Divide by accumulation count # Scale loss by nominal batch_size of 64
if accumulate > 1: loss *= batch_size / 64
loss /= accumulate
# Compute gradient # Compute gradient
if mixed_precision: if mixed_precision: