add support for standalone BatchNorm2d()

This commit is contained in:
Glenn Jocher 2020-04-03 14:21:47 -07:00
parent 682c2b27e7
commit eb9fb245aa
1 changed files with 6 additions and 4 deletions

View File

@ -9,16 +9,14 @@ def create_modules(module_defs, img_size):
# Constructs module list of layer blocks from module configuration in module_defs # Constructs module list of layer blocks from module configuration in module_defs
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
hyperparams = module_defs.pop(0) _ = module_defs.pop(0) # cfg training hyperparams (unused)
output_filters = [int(hyperparams['channels'])] output_filters = [3] # input channels
module_list = nn.ModuleList() module_list = nn.ModuleList()
routs = [] # list of layers which rout to deeper layers routs = [] # list of layers which rout to deeper layers
yolo_index = -1 yolo_index = -1
for i, mdef in enumerate(module_defs): for i, mdef in enumerate(module_defs):
modules = nn.Sequential() modules = nn.Sequential()
# if i == 0:
# modules.add_module('BatchNorm2d_0', nn.BatchNorm2d(output_filters[-1], momentum=0.1))
if mdef['type'] == 'convolutional': if mdef['type'] == 'convolutional':
bn = mdef['batch_normalize'] bn = mdef['batch_normalize']
@ -43,6 +41,10 @@ def create_modules(module_defs, img_size):
elif mdef['activation'] == 'swish': elif mdef['activation'] == 'swish':
modules.add_module('activation', Swish()) modules.add_module('activation', Swish())
elif mdef['type'] == 'BatchNorm2d':
filters = output_filters[-1]
modules = nn.BatchNorm2d(filters, momentum=0.03, eps=1E-4)
elif mdef['type'] == 'maxpool': elif mdef['type'] == 'maxpool':
size = mdef['size'] size = mdef['size']
stride = mdef['stride'] stride = mdef['stride']