From da40084b370e61ee6a9de219c86ee10b912ac8b6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 17 May 2020 21:03:36 -0700 Subject: [PATCH] burnin update --- train.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/train.py b/train.py index f7458e5b..3d4f355e 100644 --- a/train.py +++ b/train.py @@ -240,17 +240,16 @@ def train(hyp): targets = targets.to(device) # Burn-in - if ni <= n_burn * 2: - model.gr = np.interp(ni, [0, n_burn * 2], [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou) - if ni == n_burn: # burnin complete - print_model_biases(model) - + if ni <= n_burn: + xi = [0, n_burn] # x interp + model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou) + accumulate = max(1, np.interp(ni, xi, [1, 64 / batch_size]).round()) for j, x in enumerate(optimizer.param_groups): # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 - x['lr'] = np.interp(ni, [0, n_burn], [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) + x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) + x['weight_decay'] = np.interp(ni, xi, [0.0, hyp['weight_decay'] if j == 1 else 0.0]) if 'momentum' in x: - x['momentum'] = np.interp(ni, [0, n_burn], [0.9, hyp['momentum']]) - + x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']]) # Multi-Scale if opt.multi_scale: