Fuse Conv2d + BatchNorm2d

This commit is contained in:
Glenn Jocher 2019-04-20 22:46:23 +02:00
parent f9d25f6d24
commit 4a4668224b
2 changed files with 17 additions and 15 deletions

View File

@ -34,21 +34,8 @@ def detect(
else: # darknet format
_ = load_darknet_weights(model, weights)
# Fuse batchnorm
fuse = True
if fuse:
fused_list = nn.ModuleList()
for a in list(model.children())[0]:
for i, b in enumerate(a):
if isinstance(b, nn.modules.batchnorm.BatchNorm2d):
# fuse this bn layer with the previous conv2d layer
conv = a[i - 1]
fused = torch_utils.fuse_conv_and_bn(conv, b)
a = nn.Sequential(fused, *list(a.children())[i + 1:])
break
fused_list.append(a)
model.module_list = fused_list
#model_info(model) # yolov3-spp reduced from 225 to 152 layers
# Fuse Conv2d + BatchNorm2d layers
model.fuse()
model.to(device).eval()

View File

@ -212,6 +212,21 @@ class Darknet(nn.Module):
io, p = list(zip(*output)) # inference output, training output
return torch.cat(io, 1), p
def fuse(self):
# Fuse Conv2d + BatchNorm2d layers throughout model
fused_list = nn.ModuleList()
for a in list(self.children())[0]:
for i, b in enumerate(a):
if isinstance(b, nn.modules.batchnorm.BatchNorm2d):
# fuse this bn layer with the previous conv2d layer
conv = a[i - 1]
fused = torch_utils.fuse_conv_and_bn(conv, b)
a = nn.Sequential(fused, *list(a.children())[i + 1:])
break
fused_list.append(a)
self.module_list = fused_list
# model_info(self) # yolov3-spp reduced from 225 to 152 layers
def get_yolo_layers(model):
a = [module_def['type'] == 'yolo' for module_def in model.module_defs]