updates
This commit is contained in:
parent
e1425b7288
commit
4fb7fbf4bc
48
train.py
48
train.py
|
@ -73,7 +73,7 @@ def train(cfg,
|
||||||
img_size=416,
|
img_size=416,
|
||||||
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
accumulate=4): # effective bs = batch_size * accumulate = 8 * 8 = 64
|
accumulate=4): # effective bs = batch_size * accumulate = 16 * 4 = 64
|
||||||
# Initialize
|
# Initialize
|
||||||
init_seeds()
|
init_seeds()
|
||||||
weights = 'weights' + os.sep
|
weights = 'weights' + os.sep
|
||||||
|
@ -365,34 +365,34 @@ if __name__ == '__main__':
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
||||||
if opt.evolve:
|
if not opt.evolve: # Train normally
|
||||||
|
results = train(opt.cfg,
|
||||||
|
opt.data,
|
||||||
|
img_size=opt.img_size,
|
||||||
|
epochs=opt.epochs,
|
||||||
|
batch_size=opt.batch_size,
|
||||||
|
accumulate=opt.accumulate)
|
||||||
|
|
||||||
|
else: # Evolve hyperparameters (optional)
|
||||||
opt.notest = True # only test final epoch
|
opt.notest = True # only test final epoch
|
||||||
opt.nosave = True # only save final checkpoint
|
opt.nosave = True # only save final checkpoint
|
||||||
|
if opt.bucket:
|
||||||
|
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
|
||||||
|
|
||||||
# Train
|
|
||||||
results = train(opt.cfg,
|
|
||||||
opt.data,
|
|
||||||
img_size=opt.img_size,
|
|
||||||
epochs=opt.epochs,
|
|
||||||
batch_size=opt.batch_size,
|
|
||||||
accumulate=opt.accumulate)
|
|
||||||
|
|
||||||
# Evolve hyperparameters (optional)
|
|
||||||
if opt.evolve:
|
|
||||||
print_mutation(hyp, results) # Write mutation results
|
|
||||||
for _ in range(1000): # generations to evolve
|
for _ in range(1000): # generations to evolve
|
||||||
# Get best hyperparameters
|
if os._exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
|
||||||
x = np.loadtxt('evolve.txt', ndmin=2)
|
# Get best hyperparameters
|
||||||
x = x[fitness(x).argmax()] # select best fitness hyps
|
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||||
for i, k in enumerate(hyp.keys()):
|
x = x[fitness(x).argmax()] # select best fitness hyps
|
||||||
hyp[k] = x[i + 5]
|
for i, k in enumerate(hyp.keys()):
|
||||||
|
hyp[k] = x[i + 5]
|
||||||
|
|
||||||
# Mutate
|
# Mutate
|
||||||
init_seeds(seed=int(time.time()))
|
init_seeds(seed=int(time.time()))
|
||||||
s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .00, .05, .20, .20, .20, .20, .20, .20, .20] # sigmas
|
s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .00, .05, .20, .20, .20, .20, .20, .20, .20] # sigmas
|
||||||
for i, k in enumerate(hyp.keys()):
|
for i, k in enumerate(hyp.keys()):
|
||||||
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
|
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
|
||||||
hyp[k] *= float(x) # vary by sigmas
|
hyp[k] *= float(x) # vary by sigmas
|
||||||
|
|
||||||
# Clip to limits
|
# Clip to limits
|
||||||
keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale']
|
keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale']
|
||||||
|
|
|
@ -37,14 +37,12 @@ def fuse_conv_and_bn(conv, bn):
|
||||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# init
|
# init
|
||||||
fusedconv = torch.nn.Conv2d(
|
fusedconv = torch.nn.Conv2d(conv.in_channels,
|
||||||
conv.in_channels,
|
conv.out_channels,
|
||||||
conv.out_channels,
|
kernel_size=conv.kernel_size,
|
||||||
kernel_size=conv.kernel_size,
|
stride=conv.stride,
|
||||||
stride=conv.stride,
|
padding=conv.padding,
|
||||||
padding=conv.padding,
|
bias=True)
|
||||||
bias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# prepare filters
|
# prepare filters
|
||||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||||
|
|
Loading…
Reference in New Issue