add support for standalone BatchNorm2d()
This commit is contained in:
parent
682c2b27e7
commit
eb9fb245aa
10
models.py
10
models.py
|
@ -9,16 +9,14 @@ def create_modules(module_defs, img_size):
|
||||||
# Constructs module list of layer blocks from module configuration in module_defs
|
# Constructs module list of layer blocks from module configuration in module_defs
|
||||||
|
|
||||||
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
|
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
|
||||||
hyperparams = module_defs.pop(0)
|
_ = module_defs.pop(0) # cfg training hyperparams (unused)
|
||||||
output_filters = [int(hyperparams['channels'])]
|
output_filters = [3] # input channels
|
||||||
module_list = nn.ModuleList()
|
module_list = nn.ModuleList()
|
||||||
routs = [] # list of layers which rout to deeper layers
|
routs = [] # list of layers which rout to deeper layers
|
||||||
yolo_index = -1
|
yolo_index = -1
|
||||||
|
|
||||||
for i, mdef in enumerate(module_defs):
|
for i, mdef in enumerate(module_defs):
|
||||||
modules = nn.Sequential()
|
modules = nn.Sequential()
|
||||||
# if i == 0:
|
|
||||||
# modules.add_module('BatchNorm2d_0', nn.BatchNorm2d(output_filters[-1], momentum=0.1))
|
|
||||||
|
|
||||||
if mdef['type'] == 'convolutional':
|
if mdef['type'] == 'convolutional':
|
||||||
bn = mdef['batch_normalize']
|
bn = mdef['batch_normalize']
|
||||||
|
@ -43,6 +41,10 @@ def create_modules(module_defs, img_size):
|
||||||
elif mdef['activation'] == 'swish':
|
elif mdef['activation'] == 'swish':
|
||||||
modules.add_module('activation', Swish())
|
modules.add_module('activation', Swish())
|
||||||
|
|
||||||
|
elif mdef['type'] == 'BatchNorm2d':
|
||||||
|
filters = output_filters[-1]
|
||||||
|
modules = nn.BatchNorm2d(filters, momentum=0.03, eps=1E-4)
|
||||||
|
|
||||||
elif mdef['type'] == 'maxpool':
|
elif mdef['type'] == 'maxpool':
|
||||||
size = mdef['size']
|
size = mdef['size']
|
||||||
stride = mdef['stride']
|
stride = mdef['stride']
|
||||||
|
|
Loading…
Reference in New Issue