From cf51cf9c990e2ff6a3435c5910a7a291cb290138 Mon Sep 17 00:00:00 2001 From: glenn-jocher Date: Mon, 1 Jul 2019 17:14:42 +0200 Subject: [PATCH] updates --- train.py | 53 ++++++++++++++++++++--------------------------------- 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/train.py b/train.py index b911bf27..76969494 100644 --- a/train.py +++ b/train.py @@ -294,28 +294,20 @@ def print_mutation(hyp, results): b = '%11.4g' * len(hyp) % tuple(hyp.values()) # hyperparam values c = '%11.3g' * len(results) % results # results (P, R, mAP, F1, test_loss) print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c)) - with open('evolve.txt', 'a') as f: - f.write(c + b + '\n') - cloud_evolve = False - if cloud_evolve: - # download cloud_evolve.txt - cloud_file = 'https://storage.googleapis.com/yolov4/cloud_evolve.txt' - local_file = cloud_file.replace('https://', '') - name = Path(local_file).name - google_utils.download_blob(bucket_name='yolov4', source_blob_name=name, destination_file_name=local_file) - - # add result to local cloud_evolve.txt - with open(local_file, 'a') as f: + if opt.cloud_evolve: + os.system('gsutil cp gs://yolov4/evolve.txt .') # download evolve.txt + with open('evolve.txt', 'a') as f: # append result to evolve.txt + f.write(c + b + '\n') + os.system('gsutil cp evolve.txt gs://yolov4') # upload evolve.txt + else: + with open('evolve.txt', 'a') as f: f.write(c + b + '\n') - - # upload cloud_evolve.txt - google_utils.upload_blob(bucket_name='yolov4', source_file_name=local_file, destination_blob_name=name) if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--epochs', type=int, default=100, help='number of epochs') + parser.add_argument('--epochs', type=int, default=1, help='number of epochs') parser.add_argument('--batch-size', type=int, default=8, help='batch size') parser.add_argument('--accumulate', type=int, default=8, help='number of batches to accumulate before optimizing') parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path') @@ -329,10 +321,12 @@ if __name__ == '__main__': parser.add_argument('--notest', action='store_true', help='only test final epoch') parser.add_argument('--giou', action='store_true', help='use GIoU loss instead of xy, wh loss') parser.add_argument('--evolve', action='store_true', help='run hyperparameter evolution') + parser.add_argument('--cloud_evolve', action='store_true', help='--evolve from a central source') parser.add_argument('--var', default=0, type=int, help='debug variable') opt = parser.parse_args() print(opt) + opt.evolve = opt.cloud_evolve or opt.evolve if opt.evolve: opt.notest = True # only test final epoch opt.nosave = True # only save final checkpoint @@ -347,16 +341,17 @@ if __name__ == '__main__': # Evolve hyperparameters (optional) if opt.evolve: - best_fitness = results[2] # use mAP for fitness - - # Write mutation results - print_mutation(hyp, results) - gen = 1000 # generations to evolve - for _ in range(gen): + print_mutation(hyp, results) # Write mutation results - # Mutate hyperparameters - old_hyp = hyp.copy() + for _ in range(gen): + # Get best hyperparamters + x = np.loadtxt('evolve.txt', ndmin=2) + x = x[x[:, 2].argmax()] # select best mAP for fitness (col 2) + for i, k in enumerate(hyp.keys()): + hyp[k] = x[i + 5] + + # Mutate init_seeds(seed=int(time.time())) s = [.2, .2, .2, .2, .2, .2, .2, .2, .2 * 0, .2 * 0, .05 * 0, .2 * 0] # fractional sigmas for i, k in enumerate(hyp.keys()): @@ -369,25 +364,17 @@ if __name__ == '__main__': for k, v in zip(keys, limits): hyp[k] = np.clip(hyp[k], v[0], v[1]) - # Determine mutation fitness + # Train mutation results = train(opt.cfg, opt.data_cfg, img_size=opt.img_size, epochs=opt.epochs, batch_size=opt.batch_size, accumulate=opt.accumulate) - mutation_fitness = results[2] # Write mutation results print_mutation(hyp, results) - # Update hyperparameters if fitness improved - if mutation_fitness > best_fitness: - print('Fitness improved!') - best_fitness = mutation_fitness - else: - hyp = old_hyp.copy() # reset hyp to - # # Plot results # import numpy as np # import matplotlib.pyplot as plt