burnin update

This commit is contained in:
Glenn Jocher 2020-05-17 21:03:36 -07:00
parent 0c7d7427e4
commit da40084b37
1 changed files with 7 additions and 8 deletions

View File

@ -240,17 +240,16 @@ def train(hyp):
targets = targets.to(device) targets = targets.to(device)
# Burn-in # Burn-in
if ni <= n_burn * 2: if ni <= n_burn:
model.gr = np.interp(ni, [0, n_burn * 2], [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou) xi = [0, n_burn] # x interp
if ni == n_burn: # burnin complete model.gr = np.interp(ni, xi, [0.0, 1.0]) # giou loss ratio (obj_loss = 1.0 or giou)
print_model_biases(model) accumulate = max(1, np.interp(ni, xi, [1, 64 / batch_size]).round())
for j, x in enumerate(optimizer.param_groups): for j, x in enumerate(optimizer.param_groups):
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
x['lr'] = np.interp(ni, [0, n_burn], [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)]) x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
x['weight_decay'] = np.interp(ni, xi, [0.0, hyp['weight_decay'] if j == 1 else 0.0])
if 'momentum' in x: if 'momentum' in x:
x['momentum'] = np.interp(ni, [0, n_burn], [0.9, hyp['momentum']]) x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])
# Multi-Scale # Multi-Scale
if opt.multi_scale: if opt.multi_scale: