This commit is contained in:
Glenn Jocher 2018-12-16 15:16:19 +01:00
parent b52a49cf12
commit 18ccd184bf
1 changed files with 8 additions and 5 deletions

View File

@ -22,6 +22,7 @@ def train(
resume=False,
epochs=100,
batch_size=16,
accumulated_batches=1,
weights_path='weights',
report=False,
multi_scale=False,
@ -153,8 +154,8 @@ def train(
loss = model(imgs.to(device), targets, batch_report=report, var=var)
loss.backward()
# accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing
# if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
# accumulate gradient for x batches before optimizing
if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1):
optimizer.step()
optimizer.zero_grad()
@ -237,6 +238,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch')
parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step')
parser.add_argument('--data-config', type=str, default='cfg/coco.data', help='path to data config file')
parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608')
@ -244,7 +246,7 @@ if __name__ == '__main__':
parser.add_argument('--weights-path', type=str, default='weights', help='path to store weights')
parser.add_argument('--resume', action='store_true', help='resume training flag')
parser.add_argument('--report', action='store_true', help='report TP, FP, FN, P and R per batch (slower)')
parser.add_argument('--freeze', action='store_true', help='freeze darknet53.conv.74 layers for first epoche')
parser.add_argument('--freeze', action='store_true', help='freeze darknet53.conv.74 layers for first epoch')
parser.add_argument('--var', type=float, default=0, help='optional test variable')
opt = parser.parse_args()
print(opt, end='\n\n')
@ -259,6 +261,7 @@ if __name__ == '__main__':
resume=opt.resume,
epochs=opt.epochs,
batch_size=opt.batch_size,
accumulated_batches=opt.accumulated_batches,
weights_path=opt.weights_path,
report=opt.report,
multi_scale=opt.multi_scale,