This commit is contained in:
Glenn Jocher 2019-12-09 13:17:30 -08:00
parent 2391996474
commit 07c1fafba8
1 changed files with 6 additions and 6 deletions

View File

@ -22,12 +22,12 @@ def create_modules(module_defs, img_size, arc):
if mdef['type'] == 'convolutional': if mdef['type'] == 'convolutional':
bn = int(mdef['batch_normalize']) bn = int(mdef['batch_normalize'])
filters = int(mdef['filters']) filters = int(mdef['filters'])
kernel_size = int(mdef['size']) size = int(mdef['size'])
stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x'])) stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x']))
pad = (kernel_size - 1) // 2 if int(mdef['pad']) else 0 pad = (size - 1) // 2 if int(mdef['pad']) else 0
modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1], modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
out_channels=filters, out_channels=filters,
kernel_size=kernel_size, kernel_size=size,
stride=stride, stride=stride,
padding=pad, padding=pad,
bias=not bn)) bias=not bn))
@ -40,10 +40,10 @@ def create_modules(module_defs, img_size, arc):
modules.add_module('activation', Swish()) modules.add_module('activation', Swish())
elif mdef['type'] == 'maxpool': elif mdef['type'] == 'maxpool':
kernel_size = int(mdef['size']) size = int(mdef['size'])
stride = int(mdef['stride']) stride = int(mdef['stride'])
maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=int((kernel_size - 1) // 2)) maxpool = nn.MaxPool2d(kernel_size=size, stride=stride, padding=int((size - 1) // 2))
if kernel_size == 2 and stride == 1: # yolov3-tiny if size == 2 and stride == 1: # yolov3-tiny
modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1))) modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1)))
modules.add_module('MaxPool2d', maxpool) modules.add_module('MaxPool2d', maxpool)
else: else: