updates
This commit is contained in:
parent
05358accbb
commit
cf51cf9c99
53
train.py
53
train.py
|
@ -294,28 +294,20 @@ def print_mutation(hyp, results):
|
|||
b = '%11.4g' * len(hyp) % tuple(hyp.values()) # hyperparam values
|
||||
c = '%11.3g' * len(results) % results # results (P, R, mAP, F1, test_loss)
|
||||
print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))
|
||||
|
||||
if opt.cloud_evolve:
|
||||
os.system('gsutil cp gs://yolov4/evolve.txt .') # download evolve.txt
|
||||
with open('evolve.txt', 'a') as f: # append result to evolve.txt
|
||||
f.write(c + b + '\n')
|
||||
os.system('gsutil cp evolve.txt gs://yolov4') # upload evolve.txt
|
||||
else:
|
||||
with open('evolve.txt', 'a') as f:
|
||||
f.write(c + b + '\n')
|
||||
|
||||
cloud_evolve = False
|
||||
if cloud_evolve:
|
||||
# download cloud_evolve.txt
|
||||
cloud_file = 'https://storage.googleapis.com/yolov4/cloud_evolve.txt'
|
||||
local_file = cloud_file.replace('https://', '')
|
||||
name = Path(local_file).name
|
||||
google_utils.download_blob(bucket_name='yolov4', source_blob_name=name, destination_file_name=local_file)
|
||||
|
||||
# add result to local cloud_evolve.txt
|
||||
with open(local_file, 'a') as f:
|
||||
f.write(c + b + '\n')
|
||||
|
||||
# upload cloud_evolve.txt
|
||||
google_utils.upload_blob(bucket_name='yolov4', source_file_name=local_file, destination_blob_name=name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
|
||||
parser.add_argument('--epochs', type=int, default=1, help='number of epochs')
|
||||
parser.add_argument('--batch-size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--accumulate', type=int, default=8, help='number of batches to accumulate before optimizing')
|
||||
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
|
||||
|
@ -329,10 +321,12 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
||||
parser.add_argument('--giou', action='store_true', help='use GIoU loss instead of xy, wh loss')
|
||||
parser.add_argument('--evolve', action='store_true', help='run hyperparameter evolution')
|
||||
parser.add_argument('--cloud_evolve', action='store_true', help='--evolve from a central source')
|
||||
parser.add_argument('--var', default=0, type=int, help='debug variable')
|
||||
opt = parser.parse_args()
|
||||
print(opt)
|
||||
|
||||
opt.evolve = opt.cloud_evolve or opt.evolve
|
||||
if opt.evolve:
|
||||
opt.notest = True # only test final epoch
|
||||
opt.nosave = True # only save final checkpoint
|
||||
|
@ -347,16 +341,17 @@ if __name__ == '__main__':
|
|||
|
||||
# Evolve hyperparameters (optional)
|
||||
if opt.evolve:
|
||||
best_fitness = results[2] # use mAP for fitness
|
||||
|
||||
# Write mutation results
|
||||
print_mutation(hyp, results)
|
||||
|
||||
gen = 1000 # generations to evolve
|
||||
for _ in range(gen):
|
||||
print_mutation(hyp, results) # Write mutation results
|
||||
|
||||
# Mutate hyperparameters
|
||||
old_hyp = hyp.copy()
|
||||
for _ in range(gen):
|
||||
# Get best hyperparamters
|
||||
x = np.loadtxt('evolve.txt', ndmin=2)
|
||||
x = x[x[:, 2].argmax()] # select best mAP for fitness (col 2)
|
||||
for i, k in enumerate(hyp.keys()):
|
||||
hyp[k] = x[i + 5]
|
||||
|
||||
# Mutate
|
||||
init_seeds(seed=int(time.time()))
|
||||
s = [.2, .2, .2, .2, .2, .2, .2, .2, .2 * 0, .2 * 0, .05 * 0, .2 * 0] # fractional sigmas
|
||||
for i, k in enumerate(hyp.keys()):
|
||||
|
@ -369,25 +364,17 @@ if __name__ == '__main__':
|
|||
for k, v in zip(keys, limits):
|
||||
hyp[k] = np.clip(hyp[k], v[0], v[1])
|
||||
|
||||
# Determine mutation fitness
|
||||
# Train mutation
|
||||
results = train(opt.cfg,
|
||||
opt.data_cfg,
|
||||
img_size=opt.img_size,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
accumulate=opt.accumulate)
|
||||
mutation_fitness = results[2]
|
||||
|
||||
# Write mutation results
|
||||
print_mutation(hyp, results)
|
||||
|
||||
# Update hyperparameters if fitness improved
|
||||
if mutation_fitness > best_fitness:
|
||||
print('Fitness improved!')
|
||||
best_fitness = mutation_fitness
|
||||
else:
|
||||
hyp = old_hyp.copy() # reset hyp to
|
||||
|
||||
# # Plot results
|
||||
# import numpy as np
|
||||
# import matplotlib.pyplot as plt
|
||||
|
|
Loading…
Reference in New Issue