From 4a4668224ba66a2e7a21445e42481e3d44111832 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 20 Apr 2019 22:46:23 +0200 Subject: [PATCH] Fuse Conv2d + BatchNorm2d --- detect.py | 17 ++--------------- models.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/detect.py b/detect.py index dc2a3dd3..f82c4bc7 100644 --- a/detect.py +++ b/detect.py @@ -34,21 +34,8 @@ def detect( else: # darknet format _ = load_darknet_weights(model, weights) - # Fuse batchnorm - fuse = True - if fuse: - fused_list = nn.ModuleList() - for a in list(model.children())[0]: - for i, b in enumerate(a): - if isinstance(b, nn.modules.batchnorm.BatchNorm2d): - # fuse this bn layer with the previous conv2d layer - conv = a[i - 1] - fused = torch_utils.fuse_conv_and_bn(conv, b) - a = nn.Sequential(fused, *list(a.children())[i + 1:]) - break - fused_list.append(a) - model.module_list = fused_list - #model_info(model) # yolov3-spp reduced from 225 to 152 layers + # Fuse Conv2d + BatchNorm2d layers + model.fuse() model.to(device).eval() diff --git a/models.py b/models.py index 77abdd9b..e5091b9a 100755 --- a/models.py +++ b/models.py @@ -212,6 +212,21 @@ class Darknet(nn.Module): io, p = list(zip(*output)) # inference output, training output return torch.cat(io, 1), p + def fuse(self): + # Fuse Conv2d + BatchNorm2d layers throughout model + fused_list = nn.ModuleList() + for a in list(self.children())[0]: + for i, b in enumerate(a): + if isinstance(b, nn.modules.batchnorm.BatchNorm2d): + # fuse this bn layer with the previous conv2d layer + conv = a[i - 1] + fused = torch_utils.fuse_conv_and_bn(conv, b) + a = nn.Sequential(fused, *list(a.children())[i + 1:]) + break + fused_list.append(a) + self.module_list = fused_list + # model_info(self) # yolov3-spp reduced from 225 to 152 layers + def get_yolo_layers(model): a = [module_def['type'] == 'yolo' for module_def in model.module_defs]