updates
This commit is contained in:
parent
b202baa31c
commit
eeae43c414
|
@ -20,7 +20,7 @@ def create_modules(module_defs):
|
||||||
modules = nn.Sequential()
|
modules = nn.Sequential()
|
||||||
|
|
||||||
if module_def['type'] == 'convolutional':
|
if module_def['type'] == 'convolutional':
|
||||||
bn = int(module_def['batch_normalize']) if 'batch_normalize' in module_def else 0
|
bn = int(module_def['batch_normalize'])
|
||||||
filters = int(module_def['filters'])
|
filters = int(module_def['filters'])
|
||||||
kernel_size = int(module_def['size'])
|
kernel_size = int(module_def['size'])
|
||||||
pad = (kernel_size - 1) // 2 if int(module_def['pad']) else 0
|
pad = (kernel_size - 1) // 2 if int(module_def['pad']) else 0
|
||||||
|
|
|
@ -10,7 +10,7 @@ def parse_model_cfg(path):
|
||||||
module_defs.append({})
|
module_defs.append({})
|
||||||
module_defs[-1]['type'] = line[1:-1].rstrip()
|
module_defs[-1]['type'] = line[1:-1].rstrip()
|
||||||
if module_defs[-1]['type'] == 'convolutional':
|
if module_defs[-1]['type'] == 'convolutional':
|
||||||
module_defs[-1]['batch_normalize'] = 0
|
module_defs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later)
|
||||||
else:
|
else:
|
||||||
key, value = line.split("=")
|
key, value = line.split("=")
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
|
|
Loading…
Reference in New Issue