diff --git a/train.py b/train.py index 9cd25a39..85dd8c6b 100644 --- a/train.py +++ b/train.py @@ -238,12 +238,13 @@ def train(): targets = targets.to(device) # Hyperparameter Burn-in - n_burn = 100 # number of burn-in batches + n_burn = 200 # number of burn-in batches if ni <= n_burn: g = (ni / n_burn) ** 4 # gain for x in model.named_modules(): if x[0].endswith('BatchNorm2d'): - x[1].momentum = 1 - 0.9 * g # momentum falls from 1 - 0.1 + # x[1].momentum = 1 - 0.9 * g # momentum falls from 1 - 0.1 + x[1].track_running_stats = ni == n_burn for x in optimizer.param_groups: x['lr'] = x['initial_lr'] * g # gain rises from 0 - 1