add support for standalone BatchNorm2d()
This commit is contained in:
parent
682c2b27e7
commit
eb9fb245aa
10
models.py
10
models.py
|
@ -9,16 +9,14 @@ def create_modules(module_defs, img_size):
|
|||
# Constructs module list of layer blocks from module configuration in module_defs
|
||||
|
||||
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
|
||||
hyperparams = module_defs.pop(0)
|
||||
output_filters = [int(hyperparams['channels'])]
|
||||
_ = module_defs.pop(0) # cfg training hyperparams (unused)
|
||||
output_filters = [3] # input channels
|
||||
module_list = nn.ModuleList()
|
||||
routs = [] # list of layers which rout to deeper layers
|
||||
yolo_index = -1
|
||||
|
||||
for i, mdef in enumerate(module_defs):
|
||||
modules = nn.Sequential()
|
||||
# if i == 0:
|
||||
# modules.add_module('BatchNorm2d_0', nn.BatchNorm2d(output_filters[-1], momentum=0.1))
|
||||
|
||||
if mdef['type'] == 'convolutional':
|
||||
bn = mdef['batch_normalize']
|
||||
|
@ -43,6 +41,10 @@ def create_modules(module_defs, img_size):
|
|||
elif mdef['activation'] == 'swish':
|
||||
modules.add_module('activation', Swish())
|
||||
|
||||
elif mdef['type'] == 'BatchNorm2d':
|
||||
filters = output_filters[-1]
|
||||
modules = nn.BatchNorm2d(filters, momentum=0.03, eps=1E-4)
|
||||
|
||||
elif mdef['type'] == 'maxpool':
|
||||
size = mdef['size']
|
||||
stride = mdef['stride']
|
||||
|
|
Loading…
Reference in New Issue