diff --git a/train.py b/train.py index d67a6296..52648fbb 100644 --- a/train.py +++ b/train.py @@ -22,6 +22,7 @@ def train( resume=False, epochs=100, batch_size=16, + accumulated_batches=1, weights_path='weights', report=False, multi_scale=False, @@ -153,10 +154,10 @@ def train( loss = model(imgs.to(device), targets, batch_report=report, var=var) loss.backward() - # accumulated_batches = 1 # accumulate gradient for 4 batches before optimizing - # if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): - optimizer.step() - optimizer.zero_grad() + # accumulate gradient for x batches before optimizing + if ((i + 1) % accumulated_batches == 0) or (i == len(dataloader) - 1): + optimizer.step() + optimizer.zero_grad() # Running epoch-means of tracked metrics ui += 1 @@ -237,6 +238,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--epochs', type=int, default=100, help='number of epochs') parser.add_argument('--batch-size', type=int, default=16, help='size of each image batch') + parser.add_argument('--accumulated-batches', type=int, default=1, help='number of batches before optimizer step') parser.add_argument('--data-config', type=str, default='cfg/coco.data', help='path to data config file') parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path') parser.add_argument('--multi-scale', action='store_true', help='random image sizes per batch 320 - 608') @@ -244,7 +246,7 @@ if __name__ == '__main__': parser.add_argument('--weights-path', type=str, default='weights', help='path to store weights') parser.add_argument('--resume', action='store_true', help='resume training flag') parser.add_argument('--report', action='store_true', help='report TP, FP, FN, P and R per batch (slower)') - parser.add_argument('--freeze', action='store_true', help='freeze darknet53.conv.74 layers for first epoche') + parser.add_argument('--freeze', action='store_true', help='freeze darknet53.conv.74 layers for first epoch') parser.add_argument('--var', type=float, default=0, help='optional test variable') opt = parser.parse_args() print(opt, end='\n\n') @@ -259,6 +261,7 @@ if __name__ == '__main__': resume=opt.resume, epochs=opt.epochs, batch_size=opt.batch_size, + accumulated_batches=opt.accumulated_batches, weights_path=opt.weights_path, report=opt.report, multi_scale=opt.multi_scale,