auto reverse-strides for yolov4/panet
This commit is contained in:
parent
c6ea2b58ea
commit
9cc4951d4f
10
models.py
10
models.py
|
@ -5,7 +5,7 @@ from utils.parse_config import *
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = False
|
||||||
|
|
||||||
|
|
||||||
def create_modules(module_defs, img_size):
|
def create_modules(module_defs, img_size, cfg):
|
||||||
# Constructs module list of layer blocks from module configuration in module_defs
|
# Constructs module list of layer blocks from module configuration in module_defs
|
||||||
|
|
||||||
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
|
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
|
||||||
|
@ -92,14 +92,16 @@ def create_modules(module_defs, img_size):
|
||||||
|
|
||||||
elif mdef['type'] == 'yolo':
|
elif mdef['type'] == 'yolo':
|
||||||
yolo_index += 1
|
yolo_index += 1
|
||||||
stride = [32, 16, 8, 4, 2][yolo_index] # P3-P7 stride
|
stride = [32, 16, 8] # P5, P4, P3 strides
|
||||||
|
if 'panet' in cfg or 'yolov4' in cfg: # stride order reversed
|
||||||
|
stride = list(reversed(stride))
|
||||||
layers = mdef['from'] if 'from' in mdef else []
|
layers = mdef['from'] if 'from' in mdef else []
|
||||||
modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list
|
modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list
|
||||||
nc=mdef['classes'], # number of classes
|
nc=mdef['classes'], # number of classes
|
||||||
img_size=img_size, # (416, 416)
|
img_size=img_size, # (416, 416)
|
||||||
yolo_index=yolo_index, # 0, 1, 2...
|
yolo_index=yolo_index, # 0, 1, 2...
|
||||||
layers=layers, # output layers
|
layers=layers, # output layers
|
||||||
stride=stride)
|
stride=stride[yolo_index])
|
||||||
|
|
||||||
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
|
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
|
||||||
try:
|
try:
|
||||||
|
@ -221,7 +223,7 @@ class Darknet(nn.Module):
|
||||||
super(Darknet, self).__init__()
|
super(Darknet, self).__init__()
|
||||||
|
|
||||||
self.module_defs = parse_model_cfg(cfg)
|
self.module_defs = parse_model_cfg(cfg)
|
||||||
self.module_list, self.routs = create_modules(self.module_defs, img_size)
|
self.module_list, self.routs = create_modules(self.module_defs, img_size, cfg)
|
||||||
self.yolo_layers = get_yolo_layers(self)
|
self.yolo_layers = get_yolo_layers(self)
|
||||||
# torch_utils.initialize_weights(self)
|
# torch_utils.initialize_weights(self)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue