@@ -261,9 +261,8 @@ def train():
                 print('WARNING: nan loss detected, ending training')
                 return results
 
-            # Divide by accumulation count
-            if accumulate > 1:
-                loss /= accumulate
+            # Scale loss by nominal batch_size of 64
+            loss *= batch_size / 64
 
             # Compute gradient
             if mixed_precision:
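
For context, a minimal standalone sketch (not the repo's code; the `nn.Linear` model, tensor sizes, and loop are stand-ins) of why the new scaling is equivalent to the old division: when the optimizer steps once every `accumulate` mini-batches with `accumulate == 64 // batch_size`, multiplying each mini-batch loss by `batch_size / 64` makes the accumulated gradient approximate that of a single nominal batch of 64, which is what `loss /= accumulate` also did, while additionally handling `batch_size > 64` (where `accumulate` is 1 and the loss is scaled up).

```python
import torch
import torch.nn as nn

nominal_batch_size = 64
batch_size = 16                                  # actual mini-batch size
accumulate = nominal_batch_size // batch_size    # step optimizer every 4 batches

model = nn.Linear(10, 1)                         # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for i in range(8):                               # stand-in for the dataloader loop
    x = torch.randn(batch_size, 10)
    y = torch.randn(batch_size, 1)
    loss = nn.functional.mse_loss(model(x), y)

    # Scale so the gradients summed over `accumulate` backward passes
    # approximate the gradient of one nominal batch of 64 images.
    loss *= batch_size / nominal_batch_size
    loss.backward()                              # gradients accumulate in .grad

    # Optimize once per nominal batch, then clear accumulated gradients.
    if (i + 1) % accumulate == 0:
        optimizer.step()
        optimizer.zero_grad()
```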