new find_modules() fcn

This commit is contained in:
Glenn Jocher 2020-04-13 17:48:30 -07:00
parent 77e6bdd3c1
commit b8574add37
2 changed files with 6 additions and 5 deletions

View File

@ -224,7 +224,7 @@ class Darknet(nn.Module):
self.module_defs = parse_model_cfg(cfg) self.module_defs = parse_model_cfg(cfg)
self.module_list, self.routs = create_modules(self.module_defs, img_size) self.module_list, self.routs = create_modules(self.module_defs, img_size)
self.yolo_layers = get_yolo_layers(self) self.yolo_layers = torch_utils.find_modules(self, mclass=YOLOLayer)
# torch_utils.initialize_weights(self) # torch_utils.initialize_weights(self)
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346 # Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
@ -333,10 +333,6 @@ class Darknet(nn.Module):
torch_utils.model_info(self, verbose) torch_utils.model_info(self, verbose)
def get_yolo_layers(model):
return [i for i, m in enumerate(model.module_list) if m.__class__.__name__ == 'YOLOLayer'] # [89, 101, 113]
def load_darknet_weights(self, weights, cutoff=-1): def load_darknet_weights(self, weights, cutoff=-1):
# Parses and loads the weights stored in 'weights' # Parses and loads the weights stored in 'weights'

View File

@ -59,6 +59,11 @@ def initialize_weights(model):
m.momentum = 0.03 m.momentum = 0.03
def find_modules(model, mclass=nn.Conv2d):
# finds layer indices matching module class 'mclass'
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
def fuse_conv_and_bn(conv, bn): def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/ # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad(): with torch.no_grad():