weight_decay fix
This commit is contained in:
parent
ff82e4d488
commit
798a7396f1
5
train.py
5
train.py
|
@@ -261,9 +261,8 @@ def train():
|
||||||
print('WARNING: nan loss detected, ending training')
|
print('WARNING: nan loss detected, ending training')
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# Divide by accumulation count
|
# Scale loss by nominal batch_size of 64
|
||||||
if accumulate > 1:
|
loss *= batch_size / 64
|
||||||
loss /= accumulate
|
|
||||||
|
|
||||||
# Compute gradient
|
# Compute gradient
|
||||||
if mixed_precision:
|
if mixed_precision:
|
||||||
|
|
Loading…
Reference in New Issue