k for kernel_size

This commit is contained in:
Glenn Jocher 2020-04-11 12:37:03 -07:00
parent a34219a54b
commit ed1d4f5ae7
1 changed files with 8 additions and 8 deletions

View File

@ -21,20 +21,20 @@ def create_modules(module_defs, img_size):
if mdef['type'] == 'convolutional': if mdef['type'] == 'convolutional':
bn = mdef['batch_normalize'] bn = mdef['batch_normalize']
filters = mdef['filters'] filters = mdef['filters']
size = mdef['size'] k = mdef['size'] # kernel size
stride = mdef['stride'] if 'stride' in mdef else (mdef['stride_y'], mdef['stride_x']) stride = mdef['stride'] if 'stride' in mdef else (mdef['stride_y'], mdef['stride_x'])
if isinstance(size, int): # single-size conv if isinstance(k, int): # single-size conv
modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1], modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
out_channels=filters, out_channels=filters,
kernel_size=size, kernel_size=k,
stride=stride, stride=stride,
padding=size // 2 if mdef['pad'] else 0, padding=k // 2 if mdef['pad'] else 0,
groups=mdef['groups'] if 'groups' in mdef else 1, groups=mdef['groups'] if 'groups' in mdef else 1,
bias=not bn)) bias=not bn))
else: # multiple-size conv else: # multiple-size conv
modules.add_module('MixConv2d', MixConv2d(in_ch=output_filters[-1], modules.add_module('MixConv2d', MixConv2d(in_ch=output_filters[-1],
out_ch=filters, out_ch=filters,
k=size, k=k,
stride=stride, stride=stride,
bias=not bn)) bias=not bn))
@ -58,10 +58,10 @@ def create_modules(module_defs, img_size):
modules.running_var = torch.tensor([0.0524, 0.0502, 0.0506]) modules.running_var = torch.tensor([0.0524, 0.0502, 0.0506])
elif mdef['type'] == 'maxpool': elif mdef['type'] == 'maxpool':
size = mdef['size'] k = mdef['size'] # kernel size
stride = mdef['stride'] stride = mdef['stride']
maxpool = nn.MaxPool2d(kernel_size=size, stride=stride, padding=(size - 1) // 2) maxpool = nn.MaxPool2d(kernel_size=k, stride=stride, padding=(k - 1) // 2)
if size == 2 and stride == 1: # yolov3-tiny if k == 2 and stride == 1: # yolov3-tiny
modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1))) modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1)))
modules.add_module('MaxPool2d', maxpool) modules.add_module('MaxPool2d', maxpool)
else: else: