updates
This commit is contained in:
parent
2391996474
commit
07c1fafba8
12
models.py
12
models.py
|
@ -22,12 +22,12 @@ def create_modules(module_defs, img_size, arc):
|
||||||
if mdef['type'] == 'convolutional':
|
if mdef['type'] == 'convolutional':
|
||||||
bn = int(mdef['batch_normalize'])
|
bn = int(mdef['batch_normalize'])
|
||||||
filters = int(mdef['filters'])
|
filters = int(mdef['filters'])
|
||||||
kernel_size = int(mdef['size'])
|
size = int(mdef['size'])
|
||||||
stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x']))
|
stride = int(mdef['stride']) if 'stride' in mdef else (int(mdef['stride_y']), int(mdef['stride_x']))
|
||||||
pad = (kernel_size - 1) // 2 if int(mdef['pad']) else 0
|
pad = (size - 1) // 2 if int(mdef['pad']) else 0
|
||||||
modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
|
modules.add_module('Conv2d', nn.Conv2d(in_channels=output_filters[-1],
|
||||||
out_channels=filters,
|
out_channels=filters,
|
||||||
kernel_size=kernel_size,
|
kernel_size=size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=pad,
|
padding=pad,
|
||||||
bias=not bn))
|
bias=not bn))
|
||||||
|
@ -40,10 +40,10 @@ def create_modules(module_defs, img_size, arc):
|
||||||
modules.add_module('activation', Swish())
|
modules.add_module('activation', Swish())
|
||||||
|
|
||||||
elif mdef['type'] == 'maxpool':
|
elif mdef['type'] == 'maxpool':
|
||||||
kernel_size = int(mdef['size'])
|
size = int(mdef['size'])
|
||||||
stride = int(mdef['stride'])
|
stride = int(mdef['stride'])
|
||||||
maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=int((kernel_size - 1) // 2))
|
maxpool = nn.MaxPool2d(kernel_size=size, stride=stride, padding=int((size - 1) // 2))
|
||||||
if kernel_size == 2 and stride == 1: # yolov3-tiny
|
if size == 2 and stride == 1: # yolov3-tiny
|
||||||
modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1)))
|
modules.add_module('ZeroPad2d', nn.ZeroPad2d((0, 1, 0, 1)))
|
||||||
modules.add_module('MaxPool2d', maxpool)
|
modules.add_module('MaxPool2d', maxpool)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue