updates
This commit is contained in:
parent
6ca8277de2
commit
41bf46a419
2
train.py
2
train.py
|
@ -240,7 +240,7 @@ def train():
|
||||||
# Hyperparameter Burn-in
|
# Hyperparameter Burn-in
|
||||||
n_burn = 200 # number of burn-in batches
|
n_burn = 200 # number of burn-in batches
|
||||||
if ni <= n_burn:
|
if ni <= n_burn:
|
||||||
g = (ni / n_burn) ** 2 # gain
|
g = ni / n_burn # gain
|
||||||
for x in model.named_modules():
|
for x in model.named_modules():
|
||||||
if x[0].endswith('BatchNorm2d'):
|
if x[0].endswith('BatchNorm2d'):
|
||||||
# x[1].momentum = 1 - 0.9 * g # momentum falls from 1 - 0.1
|
# x[1].momentum = 1 - 0.9 * g # momentum falls from 1 - 0.1
|
||||||
|
|
Loading…
Reference in New Issue