This commit is contained in:
Glenn Jocher 2019-08-08 20:16:32 +02:00
parent c49fe688b7
commit 25bc5e5392
1 changed files with 6 additions and 5 deletions

View File

@ -215,8 +215,9 @@ def train(cfg,
targets = targets.to(device)
# Multi-Scale training
ni = (i + nb * epoch) # number integrated batches (since train start)
if multi_scale:
if (i + nb * epoch) / accumulate % 10 == 0: #  adjust (67% - 150%) every 10 batches
if ni / accumulate % 10 == 0: #  adjust (67% - 150%) every 10 batches
img_size = random.randrange(img_sz_min, img_sz_max + 1) * 32
sf = img_size / max(imgs.shape[2:]) # scale factor
if sf != 1:
@ -228,12 +229,12 @@ def train(cfg,
plot_images(imgs=imgs, targets=targets, paths=paths, fname='train_batch%g.jpg' % i)
# Hyperparameter burn-in
# n_burnin = min(round(nb / 5 + 1), 1000) # burn-in batches
# if epoch == 0 and i <= n_burnin:
# n_burn = min(nb // 5 + 1, 1000) # number of burn-in batches
# if ni <= n_burn:
# for m in model.named_modules():
# if m[0].endswith('BatchNorm2d'):
# m[1].momentum = 1 - i / n_burnin * 0.99 # BatchNorm2d momentum falls from 1 - 0.01
# g = (i / n_burnin) ** 4 # gain rises from 0 - 1
# m[1].momentum = 1 - i / n_burn * 0.99 # BatchNorm2d momentum falls from 1 - 0.01
# g = (i / n_burn) ** 4 # gain rises from 0 - 1
# for x in optimizer.param_groups:
# x['lr'] = hyp['lr0'] * g
# x['weight_decay'] = hyp['weight_decay'] * g