updates
This commit is contained in:
parent
e40d4c87f2
commit
e76d4d0ffc
5
train.py
5
train.py
|
@ -238,12 +238,13 @@ def train():
|
|||
targets = targets.to(device)
|
||||
|
||||
# Hyperparameter Burn-in
|
||||
n_burn = 100 # number of burn-in batches
|
||||
n_burn = 200 # number of burn-in batches
|
||||
if ni <= n_burn:
|
||||
g = (ni / n_burn) ** 4 # gain
|
||||
for x in model.named_modules():
|
||||
if x[0].endswith('BatchNorm2d'):
|
||||
x[1].momentum = 1 - 0.9 * g # momentum falls from 1 - 0.1
|
||||
# x[1].momentum = 1 - 0.9 * g # momentum falls from 1 - 0.1
|
||||
x[1].track_running_stats = ni == n_burn
|
||||
for x in optimizer.param_groups:
|
||||
x['lr'] = x['initial_lr'] * g # gain rises from 0 - 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue