This commit is contained in:
Glenn Jocher 2019-11-16 13:12:56 -08:00
parent 84cb744761
commit dc82956aff
1 changed files with 2 additions and 1 deletions

View File

@ -23,11 +23,12 @@ def create_modules(module_defs, img_size, arc):
bn = int(mdef['batch_normalize']) bn = int(mdef['batch_normalize'])
filters = int(mdef['filters']) filters = int(mdef['filters'])
kernel_size = int(mdef['size']) kernel_size = int(mdef['size'])
stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x']))
pad = (kernel_size - 1) // 2 if int(mdef['pad']) else 0 pad = (kernel_size - 1) // 2 if int(mdef['pad']) else 0
modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1], modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
out_channels=filters, out_channels=filters,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=int(mdef['stride']), stride=stride,
padding=pad, padding=pad,
bias=not bn)) bias=not bn))
if bn: if bn: