auto reverse-strides for yolov4/panet

This commit is contained in:
Glenn Jocher 2020-04-28 15:24:14 -07:00
parent c6ea2b58ea
commit 9cc4951d4f
1 changed files with 6 additions and 4 deletions

View File

@ -5,7 +5,7 @@ from utils.parse_config import *
ONNX_EXPORT = False ONNX_EXPORT = False
def create_modules(module_defs, img_size): def create_modules(module_defs, img_size, cfg):
# Constructs module list of layer blocks from module configuration in module_defs # Constructs module list of layer blocks from module configuration in module_defs
img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary img_size = [img_size] * 2 if isinstance(img_size, int) else img_size # expand if necessary
@ -92,14 +92,16 @@ def create_modules(module_defs, img_size):
elif mdef['type'] == 'yolo': elif mdef['type'] == 'yolo':
yolo_index += 1 yolo_index += 1
stride = [32, 16, 8, 4, 2][yolo_index] # P3-P7 stride stride = [32, 16, 8] # P5, P4, P3 strides
if 'panet' in cfg or 'yolov4' in cfg: # stride order reversed
stride = list(reversed(stride))
layers = mdef['from'] if 'from' in mdef else [] layers = mdef['from'] if 'from' in mdef else []
modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list modules = YOLOLayer(anchors=mdef['anchors'][mdef['mask']], # anchor list
nc=mdef['classes'], # number of classes nc=mdef['classes'], # number of classes
img_size=img_size, # (416, 416) img_size=img_size, # (416, 416)
yolo_index=yolo_index, # 0, 1, 2... yolo_index=yolo_index, # 0, 1, 2...
layers=layers, # output layers layers=layers, # output layers
stride=stride) stride=stride[yolo_index])
# Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3) # Initialize preceding Conv2d() bias (https://arxiv.org/pdf/1708.02002.pdf section 3.3)
try: try:
@ -221,7 +223,7 @@ class Darknet(nn.Module):
super(Darknet, self).__init__() super(Darknet, self).__init__()
self.module_defs = parse_model_cfg(cfg) self.module_defs = parse_model_cfg(cfg)
self.module_list, self.routs = create_modules(self.module_defs, img_size) self.module_list, self.routs = create_modules(self.module_defs, img_size, cfg)
self.yolo_layers = get_yolo_layers(self) self.yolo_layers = get_yolo_layers(self)
# torch_utils.initialize_weights(self) # torch_utils.initialize_weights(self)