Fix fuse in models.py
This commit is contained in:
晨太狼 2019-08-09 18:44:47 +08:00 committed by Glenn Jocher
parent fdd5afa229
commit d41b444a15
1 changed files with 8 additions and 7 deletions

View File

@ -213,13 +213,14 @@ class Darknet(nn.Module):
# Fuse Conv2d + BatchNorm2d layers throughout model # Fuse Conv2d + BatchNorm2d layers throughout model
fused_list = nn.ModuleList() fused_list = nn.ModuleList()
for a in list(self.children())[0]: for a in list(self.children())[0]:
for i, b in enumerate(a): if isinstance(a, nn.Sequential):
if isinstance(b, nn.modules.batchnorm.BatchNorm2d): for i, b in enumerate(a):
# fuse this bn layer with the previous conv2d layer if isinstance(b, nn.modules.batchnorm.BatchNorm2d):
conv = a[i - 1] # fuse this bn layer with the previous conv2d layer
fused = torch_utils.fuse_conv_and_bn(conv, b) conv = a[i - 1]
a = nn.Sequential(fused, *list(a.children())[i + 1:]) fused = torch_utils.fuse_conv_and_bn(conv, b)
break a = nn.Sequential(fused, *list(a.children())[i + 1:])
break
fused_list.append(a) fused_list.append(a)
self.module_list = fused_list self.module_list = fused_list
# model_info(self) # yolov3-spp reduced from 225 to 152 layers # model_info(self) # yolov3-spp reduced from 225 to 152 layers