updates
This commit is contained in:
parent
05358accbb
commit
cf51cf9c99
53
train.py
53
train.py
|
@ -294,28 +294,20 @@ def print_mutation(hyp, results):
|
||||||
b = '%11.4g' * len(hyp) % tuple(hyp.values()) # hyperparam values
|
b = '%11.4g' * len(hyp) % tuple(hyp.values()) # hyperparam values
|
||||||
c = '%11.3g' * len(results) % results # results (P, R, mAP, F1, test_loss)
|
c = '%11.3g' * len(results) % results # results (P, R, mAP, F1, test_loss)
|
||||||
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
|
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
|
||||||
with open('evolve.txt', 'a') as f:
|
|
||||||
f.write(c + b + '\n')
|
|
||||||
|
|
||||||
cloud_evolve = False
|
if opt.cloud_evolve:
|
||||||
if cloud_evolve:
|
os.system('gsutil cp gs://yolov4/evolve.txt .') # download evolve.txt
|
||||||
# download cloud_evolve.txt
|
with open('evolve.txt', 'a') as f: # append result to evolve.txt
|
||||||
cloud_file = 'https://storage.googleapis.com/yolov4/cloud_evolve.txt'
|
f.write(c + b + '\n')
|
||||||
local_file = cloud_file.replace('https://', '')
|
os.system('gsutil cp evolve.txt gs://yolov4') # upload evolve.txt
|
||||||
name = Path(local_file).name
|
else:
|
||||||
google_utils.download_blob(bucket_name='yolov4', source_blob_name=name, destination_file_name=local_file)
|
with open('evolve.txt', 'a') as f:
|
||||||
|
|
||||||
# add result to local cloud_evolve.txt
|
|
||||||
with open(local_file, 'a') as f:
|
|
||||||
f.write(c + b + '\n')
|
f.write(c + b + '\n')
|
||||||
|
|
||||||
# upload cloud_evolve.txt
|
|
||||||
google_utils.upload_blob(bucket_name='yolov4', source_file_name=local_file, destination_blob_name=name)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
|
parser.add_argument('--epochs', type=int, default=1, help='number of epochs')
|
||||||
parser.add_argument('--batch-size', type=int, default=8, help='batch size')
|
parser.add_argument('--batch-size', type=int, default=8, help='batch size')
|
||||||
parser.add_argument('--accumulate', type=int, default=8, help='number of batches to accumulate before optimizing')
|
parser.add_argument('--accumulate', type=int, default=8, help='number of batches to accumulate before optimizing')
|
||||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
||||||
|
@ -329,10 +321,12 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
||||||
parser.add_argument('--giou', action='store_true', help='use GIoU loss instead of xy, wh loss')
|
parser.add_argument('--giou', action='store_true', help='use GIoU loss instead of xy, wh loss')
|
||||||
parser.add_argument('--evolve', action='store_true', help='run hyperparameter evolution')
|
parser.add_argument('--evolve', action='store_true', help='run hyperparameter evolution')
|
||||||
|
parser.add_argument('--cloud_evolve', action='store_true', help='--evolve from a central source')
|
||||||
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
print(opt)
|
print(opt)
|
||||||
|
|
||||||
|
opt.evolve = opt.cloud_evolve or opt.evolve
|
||||||
if opt.evolve:
|
if opt.evolve:
|
||||||
opt.notest = True # only test final epoch
|
opt.notest = True # only test final epoch
|
||||||
opt.nosave = True # only save final checkpoint
|
opt.nosave = True # only save final checkpoint
|
||||||
|
@ -347,16 +341,17 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# Evolve hyperparameters (optional)
|
# Evolve hyperparameters (optional)
|
||||||
if opt.evolve:
|
if opt.evolve:
|
||||||
best_fitness = results[2] # use mAP for fitness
|
|
||||||
|
|
||||||
# Write mutation results
|
|
||||||
print_mutation(hyp, results)
|
|
||||||
|
|
||||||
gen = 1000 # generations to evolve
|
gen = 1000 # generations to evolve
|
||||||
for _ in range(gen):
|
print_mutation(hyp, results) # Write mutation results
|
||||||
|
|
||||||
# Mutate hyperparameters
|
for _ in range(gen):
|
||||||
old_hyp = hyp.copy()
|
# Get best hyperparamters
|
||||||
|
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||||
|
x = x[x[:, 2].argmax()] # select best mAP for fitness (col 2)
|
||||||
|
for i, k in enumerate(hyp.keys()):
|
||||||
|
hyp[k] = x[i + 5]
|
||||||
|
|
||||||
|
# Mutate
|
||||||
init_seeds(seed=int(time.time()))
|
init_seeds(seed=int(time.time()))
|
||||||
s = [.2, .2, .2, .2, .2, .2, .2, .2, .2 * 0, .2 * 0, .05 * 0, .2 * 0] # fractional sigmas
|
s = [.2, .2, .2, .2, .2, .2, .2, .2, .2 * 0, .2 * 0, .05 * 0, .2 * 0] # fractional sigmas
|
||||||
for i, k in enumerate(hyp.keys()):
|
for i, k in enumerate(hyp.keys()):
|
||||||
|
@ -369,25 +364,17 @@ if __name__ == '__main__':
|
||||||
for k, v in zip(keys, limits):
|
for k, v in zip(keys, limits):
|
||||||
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
||||||
|
|
||||||
# Determine mutation fitness
|
# Train mutation
|
||||||
results = train(opt.cfg,
|
results = train(opt.cfg,
|
||||||
opt.data_cfg,
|
opt.data_cfg,
|
||||||
img_size=opt.img_size,
|
img_size=opt.img_size,
|
||||||
epochs=opt.epochs,
|
epochs=opt.epochs,
|
||||||
batch_size=opt.batch_size,
|
batch_size=opt.batch_size,
|
||||||
accumulate=opt.accumulate)
|
accumulate=opt.accumulate)
|
||||||
mutation_fitness = results[2]
|
|
||||||
|
|
||||||
# Write mutation results
|
# Write mutation results
|
||||||
print_mutation(hyp, results)
|
print_mutation(hyp, results)
|
||||||
|
|
||||||
# Update hyperparameters if fitness improved
|
|
||||||
if mutation_fitness > best_fitness:
|
|
||||||
print('Fitness improved!')
|
|
||||||
best_fitness = mutation_fitness
|
|
||||||
else:
|
|
||||||
hyp = old_hyp.copy() # reset hyp to
|
|
||||||
|
|
||||||
# # Plot results
|
# # Plot results
|
||||||
# import numpy as np
|
# import numpy as np
|
||||||
# import matplotlib.pyplot as plt
|
# import matplotlib.pyplot as plt
|
||||||
|
|
Loading…
Reference in New Issue