This commit is contained in:
Glenn Jocher 2019-08-12 13:49:38 +02:00
parent 616bbdb435
commit 7fb64dbf67
1 changed files with 10 additions and 11 deletions

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
ONNX_EXPORT = False ONNX_EXPORT = False
def create_modules(module_defs): def create_modules(module_defs, img_size):
""" """
Constructs module list of layer blocks from module configuration in module_defs Constructs module list of layer blocks from module configuration in module_defs
""" """
@ -75,7 +75,7 @@ def create_modules(module_defs):
a = [(a[i], a[i + 1]) for i in range(0, len(a), 2)] a = [(a[i], a[i + 1]) for i in range(0, len(a), 2)]
modules = YOLOLayer(anchors=[a[i] for i in mask], # anchor list modules = YOLOLayer(anchors=[a[i] for i in mask], # anchor list
nc=int(mdef['classes']), # number of classes nc=int(mdef['classes']), # number of classes
img_size=hyperparams['height'], # 416 img_size=img_size, # (416, 416)
yolo_index=yolo_index) # 0, 1 or 2 yolo_index=yolo_index) # 0, 1 or 2
else: else:
print('Warning: Unrecognized Layer Type: ' + mdef['type']) print('Warning: Unrecognized Layer Type: ' + mdef['type'])
@ -175,8 +175,7 @@ class Darknet(nn.Module):
super(Darknet, self).__init__() super(Darknet, self).__init__()
self.module_defs = parse_model_cfg(cfg) self.module_defs = parse_model_cfg(cfg)
self.module_defs[0]['height'] = img_size self.module_list, self.routs = create_modules(self.module_defs, img_size)
self.module_list, self.routs = create_modules(self.module_defs)
self.yolo_layers = get_yolo_layers(self) self.yolo_layers = get_yolo_layers(self)
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346 # Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
@ -193,16 +192,16 @@ class Darknet(nn.Module):
if mtype in ['convolutional', 'upsample', 'maxpool']: if mtype in ['convolutional', 'upsample', 'maxpool']:
x = module(x) x = module(x)
elif mtype == 'route': elif mtype == 'route':
layer_i = [int(x) for x in mdef['layers'].split(',')] layers = [int(x) for x in mdef['layers'].split(',')]
if len(layer_i) == 1: if len(layers) == 1:
x = layer_outputs[layer_i[0]] x = layer_outputs[layers[0]]
else: else:
try: try:
x = torch.cat([layer_outputs[i] for i in layer_i], 1) x = torch.cat([layer_outputs[i] for i in layers], 1)
except: # apply stride 2 for darknet reorg layer except: # apply stride 2 for darknet reorg layer
layer_outputs[layer_i[1]] = F.interpolate(layer_outputs[layer_i[1]], scale_factor=[0.5, 0.5]) layer_outputs[layers[1]] = F.interpolate(layer_outputs[layers[1]], scale_factor=[0.5, 0.5])
x = torch.cat([layer_outputs[i] for i in layer_i], 1) x = torch.cat([layer_outputs[i] for i in layers], 1)
# print(''), [print(layer_outputs[i].shape) for i in layer_i], print(x.shape) # print(''), [print(layer_outputs[i].shape) for i in layers], print(x.shape)
elif mtype == 'shortcut': elif mtype == 'shortcut':
x = x + layer_outputs[int(mdef['from'])] x = x + layer_outputs[int(mdef['from'])]
elif mtype == 'yolo': elif mtype == 'yolo':