Fuse Conv2d + BatchNorm2d
This commit is contained in:
parent
f9d25f6d24
commit
4a4668224b
17
detect.py
17
detect.py
|
@ -34,21 +34,8 @@ def detect(
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
_ = load_darknet_weights(model, weights)
|
_ = load_darknet_weights(model, weights)
|
||||||
|
|
||||||
# Fuse batchnorm
|
# Fuse Conv2d + BatchNorm2d layers
|
||||||
fuse = True
|
model.fuse()
|
||||||
if fuse:
|
|
||||||
fused_list = nn.ModuleList()
|
|
||||||
for a in list(model.children())[0]:
|
|
||||||
for i, b in enumerate(a):
|
|
||||||
if isinstance(b, nn.modules.batchnorm.BatchNorm2d):
|
|
||||||
# fuse this bn layer with the previous conv2d layer
|
|
||||||
conv = a[i - 1]
|
|
||||||
fused = torch_utils.fuse_conv_and_bn(conv, b)
|
|
||||||
a = nn.Sequential(fused, *list(a.children())[i + 1:])
|
|
||||||
break
|
|
||||||
fused_list.append(a)
|
|
||||||
model.module_list = fused_list
|
|
||||||
#model_info(model) # yolov3-spp reduced from 225 to 152 layers
|
|
||||||
|
|
||||||
model.to(device).eval()
|
model.to(device).eval()
|
||||||
|
|
||||||
|
|
15
models.py
15
models.py
|
@ -212,6 +212,21 @@ class Darknet(nn.Module):
|
||||||
io, p = list(zip(*output)) # inference output, training output
|
io, p = list(zip(*output)) # inference output, training output
|
||||||
return torch.cat(io, 1), p
|
return torch.cat(io, 1), p
|
||||||
|
|
||||||
|
def fuse(self):
|
||||||
|
# Fuse Conv2d + BatchNorm2d layers throughout model
|
||||||
|
fused_list = nn.ModuleList()
|
||||||
|
for a in list(self.children())[0]:
|
||||||
|
for i, b in enumerate(a):
|
||||||
|
if isinstance(b, nn.modules.batchnorm.BatchNorm2d):
|
||||||
|
# fuse this bn layer with the previous conv2d layer
|
||||||
|
conv = a[i - 1]
|
||||||
|
fused = torch_utils.fuse_conv_and_bn(conv, b)
|
||||||
|
a = nn.Sequential(fused, *list(a.children())[i + 1:])
|
||||||
|
break
|
||||||
|
fused_list.append(a)
|
||||||
|
self.module_list = fused_list
|
||||||
|
# model_info(self) # yolov3-spp reduced from 225 to 152 layers
|
||||||
|
|
||||||
|
|
||||||
def get_yolo_layers(model):
|
def get_yolo_layers(model):
|
||||||
a = [module_def['type'] == 'yolo' for module_def in model.module_defs]
|
a = [module_def['type'] == 'yolo' for module_def in model.module_defs]
|
||||||
|
|
Loading…
Reference in New Issue