return get_yolo_layers()
This commit is contained in:
parent
b8574add37
commit
ca3a9fcb0b
|
@ -224,7 +224,7 @@ class Darknet(nn.Module):
|
||||||
|
|
||||||
self.module_defs = parse_model_cfg(cfg)
|
self.module_defs = parse_model_cfg(cfg)
|
||||||
self.module_list, self.routs = create_modules(self.module_defs, img_size)
|
self.module_list, self.routs = create_modules(self.module_defs, img_size)
|
||||||
self.yolo_layers = torch_utils.find_modules(self, mclass=YOLOLayer)
|
self.yolo_layers = get_yolo_layers(self)
|
||||||
# torch_utils.initialize_weights(self)
|
# torch_utils.initialize_weights(self)
|
||||||
|
|
||||||
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||||
|
@ -333,6 +333,10 @@ class Darknet(nn.Module):
|
||||||
torch_utils.model_info(self, verbose)
|
torch_utils.model_info(self, verbose)
|
||||||
|
|
||||||
|
|
||||||
|
def get_yolo_layers(model):
|
||||||
|
return [i for i, m in enumerate(model.module_list) if m.__class__.__name__ == 'YOLOLayer'] # [89, 101, 113]
|
||||||
|
|
||||||
|
|
||||||
def load_darknet_weights(self, weights, cutoff=-1):
|
def load_darknet_weights(self, weights, cutoff=-1):
|
||||||
# Parses and loads the weights stored in 'weights'
|
# Parses and loads the weights stored in 'weights'
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue