This commit is contained in:
Glenn Jocher 2019-08-23 13:25:27 +02:00
parent 0d71fd8228
commit fd653eca8a
1 changed files with 16 additions and 24 deletions

View File

@ -75,12 +75,14 @@ hyp = {'giou': 1.582, # giou loss gain
# 'shear': 0.434} # image shear (+/- deg) # 'shear': 0.434} # image shear (+/- deg)
def train(cfg, def train():
data, cfg = opt.cfg
img_size=416, data = opt.data
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs img_size = opt.img_size
batch_size=16, epochs = opt.epochs # 500200 batches at bs 16, 117263 images = 273 epochs
accumulate=4): # effective bs = batch_size * accumulate = 16 * 4 = 64 batch_size = opt.batch_size
accumulate = opt.accumulate # effective bs = batch_size * accumulate = 16 * 4 = 64
# Initialize # Initialize
init_seeds() init_seeds()
weights = 'weights' + os.sep weights = 'weights' + os.sep
@ -359,16 +361,16 @@ def train(cfg,
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=273, help='number of epochs') parser.add_argument('--epochs', type=int, default=273) # 500200 batches at bs 16, 117263 images = 273 epochs
parser.add_argument('--batch-size', type=int, default=32, help='batch size') parser.add_argument('--batch-size', type=int, default=32) # effective bs = batch_size * accumulate = 16 * 4 = 64
parser.add_argument('--accumulate', type=int, default=2, help='number of batches to accumulate before optimizing') parser.add_argument('--accumulate', type=int, default=2, help='batches to accumulate before optimizing')
parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path') parser.add_argument('--cfg', type=str, default='cfg/yolov3-spp.cfg', help='cfg file path')
parser.add_argument('--data', type=str, default='data/coco.data', help='coco.data file path') parser.add_argument('--data', type=str, default='data/coco.data', help='*.data file path')
parser.add_argument('--multi-scale', action='store_true', help='train at (1/1.5)x - 1.5x sizes') parser.add_argument('--multi-scale', action='store_true', help='train at (1/1.5)x - 1.5x sizes')
parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)') parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
parser.add_argument('--rect', action='store_true', help='rectangular training') parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
parser.add_argument('--transfer', action='store_true', help='transfer learning flag') parser.add_argument('--transfer', action='store_true', help='transfer learning')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch') parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
@ -388,12 +390,7 @@ if __name__ == '__main__':
except: except:
pass pass
results = train(opt.cfg, results = train()
opt.data,
img_size=opt.img_size,
epochs=opt.epochs,
batch_size=opt.batch_size,
accumulate=opt.accumulate)
else: # Evolve hyperparameters (optional) else: # Evolve hyperparameters (optional)
opt.notest = True # only test final epoch opt.notest = True # only test final epoch
@ -423,12 +420,7 @@ if __name__ == '__main__':
hyp[k] = np.clip(hyp[k], v[0], v[1]) hyp[k] = np.clip(hyp[k], v[0], v[1])
# Train mutation # Train mutation
results = train(opt.cfg, results = train()
opt.data,
img_size=opt.img_size,
epochs=opt.epochs,
batch_size=opt.batch_size,
accumulate=opt.accumulate)
# Write mutation results # Write mutation results
print_mutation(hyp, results, opt.bucket) print_mutation(hyp, results, opt.bucket)