burnin lr ramp 300 iterations
This commit is contained in:
parent
ac2aa56e0a
commit
b2d9f1898f
27
train.py
27
train.py
|
@ -241,23 +241,16 @@ def train():
|
||||||
targets = targets.to(device)
|
targets = targets.to(device)
|
||||||
|
|
||||||
# Hyperparameter Burn-in
|
# Hyperparameter Burn-in
|
||||||
n_burn = 200 # number of burn-in batches
|
n_burn = 300 # number of burn-in batches
|
||||||
if ni <= n_burn:
|
if ni <= n_burn:
|
||||||
# g = (ni / n_burn) ** 2 # gain
|
g = (ni / n_burn) ** 2 # gain
|
||||||
for x in model.named_modules(): # initial stats may be poor, wait to track
|
for x in model.named_modules(): # initial stats may be poor, wait to track
|
||||||
if x[0].endswith('BatchNorm2d'):
|
if x[0].endswith('BatchNorm2d'):
|
||||||
x[1].track_running_stats = ni == n_burn
|
x[1].track_running_stats = ni == n_burn
|
||||||
# for x in optimizer.param_groups:
|
for x in optimizer.param_groups:
|
||||||
# x['lr'] = x['initial_lr'] * lf(epoch) * g # gain rises from 0 - 1
|
x['lr'] = x['initial_lr'] * lf(epoch) * g # gain rises from 0 - 1
|
||||||
# if 'momentum' in x:
|
if 'momentum' in x:
|
||||||
# x['momentum'] = hyp['momentum'] * g
|
x['momentum'] = hyp['momentum'] * g
|
||||||
|
|
||||||
# Plot images with bounding boxes
|
|
||||||
if ni < 1:
|
|
||||||
f = 'train_batch%g.png' % i # filename
|
|
||||||
plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)
|
|
||||||
if tb_writer:
|
|
||||||
tb_writer.add_image(f, cv2.imread(f)[:, :, ::-1], dataformats='HWC')
|
|
||||||
|
|
||||||
# Multi-Scale training
|
# Multi-Scale training
|
||||||
if opt.multi_scale:
|
if opt.multi_scale:
|
||||||
|
@ -299,6 +292,14 @@ def train():
|
||||||
s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
|
s = ('%10s' * 2 + '%10.3g' * 6) % ('%g/%g' % (epoch, epochs - 1), mem, *mloss, len(targets), img_size)
|
||||||
pbar.set_description(s)
|
pbar.set_description(s)
|
||||||
|
|
||||||
|
# Plot images with bounding boxes
|
||||||
|
if ni < 1:
|
||||||
|
f = 'train_batch%g.png' % i # filename
|
||||||
|
plot_images(imgs=imgs, targets=targets, paths=paths, fname=f)
|
||||||
|
if tb_writer:
|
||||||
|
tb_writer.add_image(f, cv2.imread(f)[:, :, ::-1], dataformats='HWC')
|
||||||
|
# tb_writer.add_graph(model, imgs)
|
||||||
|
|
||||||
# end batch ------------------------------------------------------------------------------------------------
|
# end batch ------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Update scheduler
|
# Update scheduler
|
||||||
|
|
Loading…
Reference in New Issue