diff --git a/models.py b/models.py index 7884e603..616cb924 100755 --- a/models.py +++ b/models.py @@ -224,7 +224,7 @@ class Darknet(nn.Module): self.module_defs = parse_model_cfg(cfg) self.module_list, self.routs = create_modules(self.module_defs, img_size) - self.yolo_layers = get_yolo_layers(self) + self.yolo_layers = torch_utils.find_modules(self, mclass=YOLOLayer) # torch_utils.initialize_weights(self) # Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346 @@ -333,10 +333,6 @@ class Darknet(nn.Module): torch_utils.model_info(self, verbose) -def get_yolo_layers(model): - return [i for i, m in enumerate(model.module_list) if m.__class__.__name__ == 'YOLOLayer'] # [89, 101, 113] - - def load_darknet_weights(self, weights, cutoff=-1): # Parses and loads the weights stored in 'weights' diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 9d2491f3..0e1ade3a 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -59,6 +59,11 @@ def initialize_weights(model): m.momentum = 0.03 +def find_modules(model, mclass=nn.Conv2d): + # finds layer indices matching module class 'mclass' + return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)] + + def fuse_conv_and_bn(conv, bn): # https://tehnokv.com/posts/fusing-batchnorm-and-conv/ with torch.no_grad():