diff --git a/models.py b/models.py index d83ef0d5..d352e873 100755 --- a/models.py +++ b/models.py @@ -23,11 +23,12 @@ def create_modules(module_defs, img_size, arc): bn = int(mdef['batch_normalize']) filters = int(mdef['filters']) kernel_size = int(mdef['size']) + stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x'])) pad = (kernel_size - 1) // 2 if int(mdef['pad']) else 0 modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1], out_channels=filters, kernel_size=kernel_size, - stride=int(mdef['stride']), + stride=stride, padding=pad, bias=not bn)) if bn: