Glenn Jocher 2019-07-24 19:02:24 +02:00
parent e1425b7288
commit 4fb7fbf4bc
2 changed files with 30 additions and 32 deletions

View File

@@ -73,7 +73,7 @@ def train(cfg,
           img_size=416,
           epochs=100,  # 500200 batches at bs 16, 117263 images = 273 epochs
           batch_size=16,
-          accumulate=4):  # effective bs = batch_size * accumulate = 8 * 8 = 64
+          accumulate=4):  # effective bs = batch_size * accumulate = 16 * 4 = 64
     # Initialize
     init_seeds()
     weights = 'weights' + os.sep
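
The corrected comment reflects how gradient accumulation works: gradients from `accumulate` mini-batches are summed before one optimizer step, so the effective batch size is batch_size * accumulate = 16 * 4 = 64. A minimal generic PyTorch sketch of that pattern (not this repo's exact training loop; `model`, `optimizer`, and `loader` are placeholders):

def train_epoch(model, optimizer, loader, accumulate=4):
    # Gradients are summed across `accumulate` mini-batches, then applied in one step,
    # giving an effective batch size of batch_size * accumulate (16 * 4 = 64 here).
    optimizer.zero_grad()
    for i, (imgs, targets) in enumerate(loader):
        loss = model(imgs, targets)  # assumed to return a scalar loss
        loss.backward()              # accumulates into .grad without clearing it
        if (i + 1) % accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()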
@@ -365,34 +365,34 @@ if __name__ == '__main__':
     opt = parser.parse_args()
     print(opt)

-    if opt.evolve:
+    if not opt.evolve:  # Train normally
+        results = train(opt.cfg,
+                        opt.data,
+                        img_size=opt.img_size,
+                        epochs=opt.epochs,
+                        batch_size=opt.batch_size,
+                        accumulate=opt.accumulate)
+
+    else:  # Evolve hyperparameters (optional)
         opt.notest = True  # only test final epoch
         opt.nosave = True  # only save final checkpoint
+        if opt.bucket:
+            os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket)  # download evolve.txt if exists

-    # Train
-    results = train(opt.cfg,
-                    opt.data,
-                    img_size=opt.img_size,
-                    epochs=opt.epochs,
-                    batch_size=opt.batch_size,
-                    accumulate=opt.accumulate)
-
-    # Evolve hyperparameters (optional)
-    if opt.evolve:
-        print_mutation(hyp, results)  # Write mutation results
-
         for _ in range(1000):  # generations to evolve
-            # Get best hyperparameters
-            x = np.loadtxt('evolve.txt', ndmin=2)
-            x = x[fitness(x).argmax()]  # select best fitness hyps
-            for i, k in enumerate(hyp.keys()):
-                hyp[k] = x[i + 5]
+            if os._exists('evolve.txt'):  # if evolve.txt exists: select best hyps and mutate
+                # Get best hyperparameters
+                x = np.loadtxt('evolve.txt', ndmin=2)
+                x = x[fitness(x).argmax()]  # select best fitness hyps
+                for i, k in enumerate(hyp.keys()):
+                    hyp[k] = x[i + 5]

             # Mutate
             init_seeds(seed=int(time.time()))
             s = [.15, .15, .15, .15, .15, .15, .15, .15, .15, .00, .05, .20, .20, .20, .20, .20, .20, .20]  # sigmas
             for i, k in enumerate(hyp.keys()):
                 x = (np.random.randn(1) * s[i] + 1) ** 2.0  # plt.hist(x.ravel(), 300)
                 hyp[k] *= float(x)  # vary by sigmas

             # Clip to limits
             keys = ['lr0', 'iou_t', 'momentum', 'weight_decay', 'hsv_s', 'hsv_v', 'translate', 'scale']
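
The restructured block trains normally when --evolve is off; otherwise it loops over generations, and if an evolve.txt from earlier runs exists (optionally pulled from a GCS bucket) it selects the best previous hyperparameters by fitness before perturbing them with multiplicative noise. A standalone sketch of that select-and-mutate idea follows; the toy `hyp` dict, the fitness column index, and the use of os.path.exists (the public counterpart of the undocumented os._exists used above) are illustrative assumptions, not this repo's exact evolve.txt layout:

import os
import numpy as np

hyp = {'lr0': 0.001, 'momentum': 0.9, 'weight_decay': 0.0005}  # toy hyperparameter dict

if os.path.exists('evolve.txt'):           # portable existence check
    x = np.loadtxt('evolve.txt', ndmin=2)  # one row per previous generation
    best = x[x[:, 0].argmax()]             # assume column 0 stores the fitness score
    for i, k in enumerate(hyp.keys()):
        hyp[k] = best[i + 1]               # assume the remaining columns store the hyps

s = 0.2                                    # mutation sigma
for k in hyp:
    hyp[k] *= float((np.random.randn(1) * s + 1) ** 2.0)  # multiplicative noise, mean ~1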

View File

@@ -37,14 +37,12 @@ def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
         # init
-        fusedconv = torch.nn.Conv2d(
-            conv.in_channels,
-            conv.out_channels,
-            kernel_size=conv.kernel_size,
-            stride=conv.stride,
-            padding=conv.padding,
-            bias=True
-        )
+        fusedconv = torch.nn.Conv2d(conv.in_channels,
+                                    conv.out_channels,
+                                    kernel_size=conv.kernel_size,
+                                    stride=conv.stride,
+                                    padding=conv.padding,
+                                    bias=True)

         # prepare filters
         w_conv = conv.weight.clone().view(conv.out_channels, -1)
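
This hunk only reflows the Conv2d constructor call; the surrounding function folds a BatchNorm layer into the preceding convolution, following the linked tehnokv post. For context, a self-contained sketch of that fusion (assuming groups=1; not necessarily byte-identical to the rest of this function):

import torch
import torch.nn as nn

def fuse_conv_and_bn_sketch(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # BN applies y = gamma * (z - mean) / sqrt(var + eps) + beta to the conv output z,
    # so the scale folds into the conv weights and the shift into a new bias.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      kernel_size=conv.kernel_size,
                      stride=conv.stride,
                      padding=conv.padding,
                      bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # per-channel gamma / sigma
    w = conv.weight.clone().view(conv.out_channels, -1)
    fused.weight.data.copy_((torch.diag(scale) @ w).view(fused.weight.shape))
    b = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data.copy_(scale * (b - bn.running_mean) + bn.bias)
    return fused

At inference time the fused layer replaces the conv+BN pair, removing the BatchNorm entirely.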