updates
This commit is contained in:
parent
e1425b7288
commit
4fb7fbf4bc
18
train.py
18
train.py
|
@ -73,7 +73,7 @@ def train(cfg,
|
||||||
img_size=416,
|
img_size=416,
|
||||||
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
accumulate=4): # effective bs = batch_size * accumulate = 8 * 8 = 64
|
accumulate=4): # effective bs = batch_size * accumulate = 16 * 4 = 64
|
||||||
# Initialize
|
# Initialize
|
||||||
init_seeds()
|
init_seeds()
|
||||||
weights = 'weights' + os.sep
|
weights = 'weights' + os.sep
|
||||||
|
@ -365,11 +365,7 @@ if __name__ == '__main__':
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
||||||
if opt.evolve:
|
if not opt.evolve: # Train normally
|
||||||
opt.notest = True # only test final epoch
|
|
||||||
opt.nosave = True # only save final checkpoint
|
|
||||||
|
|
||||||
# Train
|
|
||||||
results = train(opt.cfg,
|
results = train(opt.cfg,
|
||||||
opt.data,
|
opt.data,
|
||||||
img_size=opt.img_size,
|
img_size=opt.img_size,
|
||||||
|
@ -377,10 +373,14 @@ if __name__ == '__main__':
|
||||||
batch_size=opt.batch_size,
|
batch_size=opt.batch_size,
|
||||||
accumulate=opt.accumulate)
|
accumulate=opt.accumulate)
|
||||||
|
|
||||||
# Evolve hyperparameters (optional)
|
else: # Evolve hyperparameters (optional)
|
||||||
if opt.evolve:
|
opt.notest = True # only test final epoch
|
||||||
print_mutation(hyp, results) # Write mutation results
|
opt.nosave = True # only save final checkpoint
|
||||||
|
if opt.bucket:
|
||||||
|
os.system('gsutil cp gs://%s/evolve.txt .' % opt.bucket) # download evolve.txt if exists
|
||||||
|
|
||||||
for _ in range(1000): # generations to evolve
|
for _ in range(1000): # generations to evolve
|
||||||
|
if os._exists('evolve.txt'): # if evolve.txt exists: select best hyps and mutate
|
||||||
# Get best hyperparameters
|
# Get best hyperparameters
|
||||||
x = np.loadtxt('evolve.txt', ndmin=2)
|
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||||
x = x[fitness(x).argmax()] # select best fitness hyps
|
x = x[fitness(x).argmax()] # select best fitness hyps
|
||||||
|
|
|
@ -37,14 +37,12 @@ def fuse_conv_and_bn(conv, bn):
|
||||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# init
|
# init
|
||||||
fusedconv = torch.nn.Conv2d(
|
fusedconv = torch.nn.Conv2d(conv.in_channels,
|
||||||
conv.in_channels,
|
|
||||||
conv.out_channels,
|
conv.out_channels,
|
||||||
kernel_size=conv.kernel_size,
|
kernel_size=conv.kernel_size,
|
||||||
stride=conv.stride,
|
stride=conv.stride,
|
||||||
padding=conv.padding,
|
padding=conv.padding,
|
||||||
bias=True
|
bias=True)
|
||||||
)
|
|
||||||
|
|
||||||
# prepare filters
|
# prepare filters
|
||||||
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
||||||
|
|
Loading…
Reference in New Issue