This commit is contained in:
Glenn Jocher 2019-08-03 14:38:06 +02:00
parent cd1f1eeecc
commit 2d8311a83f
1 changed files with 13 additions and 11 deletions

View File

@ -24,25 +24,27 @@ def create_modules(module_defs):
filters = int(module_def['filters']) filters = int(module_def['filters'])
kernel_size = int(module_def['size']) kernel_size = int(module_def['size'])
pad = (kernel_size - 1) // 2 if int(module_def['pad']) else 0 pad = (kernel_size - 1) // 2 if int(module_def['pad']) else 0
modules.add_module('conv_%d' % i, nn.Conv2d(in_channels=output_filters[-1], modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
out_channels=filters, out_channels=filters,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=int(module_def['stride']), stride=int(module_def['stride']),
padding=pad, padding=pad,
bias=not bn)) bias=not bn))
if bn: if bn:
modules.add_module('batch_norm_%d' % i, nn.BatchNorm2d(filters)) modules.add_module('BatchNorm2d', nn.BatchNorm2d(filters))
if module_def['activation'] == 'leaky': if module_def['activation'] == 'leaky':
# modules.add_module('leaky_%d' % i, nn.PReLU(num_parameters=filters, init=0.10)) # modules.add_module('activation', nn.PReLU(num_parameters=filters, init=0.1))
modules.add_module('leaky_%d' % i, nn.LeakyReLU(0.1, inplace=True)) modules.add_module('activation', nn.LeakyReLU(0.1, inplace=True))
elif module_def['type'] == 'maxpool': elif module_def['type'] == 'maxpool':
kernel_size = int(module_def['size']) kernel_size = int(module_def['size'])
stride = int(module_def['stride']) stride = int(module_def['stride'])
maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=int((kernel_size - 1) // 2)) maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=int((kernel_size - 1) // 2))
if kernel_size == 2 and stride == 1: # yolov3-tiny if kernel_size == 2 and stride == 1: # yolov3-tiny
modules.add_module('_debug_padding_%d' % i, nn.ZeroPad2d((0, 1, 0, 1))) modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1)))
modules.add_module('maxpool_%d' % i, maxpool) modules.add_module('MaxPool2d', maxpool)
else:
modules = maxpool
elif module_def['type'] == 'upsample': elif module_def['type'] == 'upsample':
modules = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest') modules = nn.Upsample(scale_factor=int(module_def['stride']), mode='nearest')