diff --git a/train.py b/train.py
index d4bfc08d..9cd25a39 100644
--- a/train.py
+++ b/train.py
@@ -237,11 +237,15 @@ def train():
             imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
             targets = targets.to(device)
 
-            # Hyperparameter burn-in
-            # n_burn = 100  # number of burn-in batches
-            # if ni < n_burn:
-            #     for x in optimizer.param_groups:
-            #         x['lr'] = x['initial_lr'] * (ni / n_burn) ** 4  # gain rises from 0 - 1
+            # Hyperparameter burn-in: ramp the gain over the first n_burn batches
+            n_burn = 100  # number of burn-in batches
+            if ni <= n_burn:
+                g = (ni / n_burn) ** 4  # gain, quartic in the batch index
+                for module_name, module in model.named_modules():
+                    if module_name.endswith('BatchNorm2d'):  # NOTE(review): matches on module name, not type - confirm model naming
+                        module.momentum = 1 - 0.9 * g  # momentum falls from 1 to 0.1 as gain rises
+                for param_group in optimizer.param_groups:
+                    param_group['lr'] = param_group['initial_lr'] * g  # lr scaled by gain (0 to 1)
 
             # Plot images with bounding boxes
             if ni < 1: