updates
This commit is contained in:
parent
e1425b7288
commit
4fb7fbf4bc
48
train.py
48
train.py
|
@ -73,7 +73,7 @@ def train(cfg,
|
|||
img_size=416,
|
||||
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
||||
batch_size=16,
|
||||
accumulate=4): # effective bs = batch_size * accumulate = 8 * 8 = 64
|
||||
accumulate=4): # effective bs = batch_size * accumulate = 16 * 4 = 64
|
||||
# Initialize
|
||||
init_seeds()
|
||||
weights = 'weights' + os.sep
|
||||
|
@ -365,34 +365,34 @@ if __name__ == '__main__':
|
|||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
if opt.evolve:
|
||||
if not opt.evolve: # Train normally
|
||||
results = train(opt.cfg,
|
||||
opt.data,
|
||||
img_size=opt.img_size,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
accumulate=opt.accumulate)
|
||||
|
||||
else: # Evolve hyperparameters (optional)
|
||||
opt.notest = True # only test final epoch
|
||||
opt.nosave = True # only save final checkpoint
|
||||
if opt.bucket:
|
||||
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
|
||||
|
||||
# Train
|
||||
results = train(opt.cfg,
|
||||
opt.data,
|
||||
img_size=opt.img_size,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
accumulate=opt.accumulate)
|
||||
|
||||
# Evolve hyperparameters (optional)
|
||||
if opt.evolve:
|
||||
print_mutation(hyp, results) # Write mutation results
|
||||
for _ in range(1000): # generations to evolve
|
||||
# Get best hyperparameters
|
||||
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||
x = x[fitness(x).argmax()] # select best fitness hyps
|
||||
for i, k in enumerate(hyp.keys()):
|
||||
hyp[k] = x[i + 5]
|
||||
if os._exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
|
||||
# Get best hyperparameters
|
||||
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||
x = x[fitness(x).argmax()] # select best fitness hyps
|
||||
for i, k in enumerate(hyp.keys()):
|
||||
hyp[k] = x[i + 5]
|
||||
|
||||
# Mutate
|
||||
init_seeds(seed=int(time.time()))
|
||||
s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .00, .05, .20, .20, .20, .20, .20, .20, .20] # sigmas
|
||||
for i, k in enumerate(hyp.keys()):
|
||||
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
|
||||
hyp[k] *= float(x) # vary by sigmas
|
||||
# Mutate
|
||||
init_seeds(seed=int(time.time()))
|
||||
s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .00, .05, .20, .20, .20, .20, .20, .20, .20] # sigmas
|
||||
for i, k in enumerate(hyp.keys()):
|
||||
x = (np.random.randn(1) * s[i] + 1) ** 2.0 # plt.hist(x.ravel(), 300)
|
||||
hyp[k] *= float(x) # vary by sigmas
|
||||
|
||||
# Clip to limits
|
||||
keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale']
|
||||
|
|
|
@ -37,14 +37,12 @@ def fuse_conv_and_bn(conv, bn):
|
|||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||
with torch.no_grad():
|
||||
# init
|
||||
fusedconv = torch.nn.Conv2d(
|
||||
conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
bias=True
|
||||
)
|
||||
fusedconv = torch.nn.Conv2d(conv.in_channels,
|
||||
conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
bias=True)
|
||||
|
||||
# prepare filters
|
||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||
|
|
Loading…
Reference in New Issue