This commit is contained in:
glenn-jocher 2019-07-03 14:42:11 +02:00
parent a8cf64af31
commit 1d0a4a3ace
1 changed files with 5 additions and 4 deletions

View File

@ -15,6 +15,7 @@ def create_modules(module_defs):
hyperparams = module_defs.pop(0) hyperparams = module_defs.pop(0)
output_filters = [int(hyperparams['channels'])] output_filters = [int(hyperparams['channels'])]
module_list = nn.ModuleList() module_list = nn.ModuleList()
yolo_index = -1
for i, module_def in enumerate(module_defs): for i, module_def in enumerate(module_defs):
modules = nn.Sequential() modules = nn.Sequential()
@ -58,6 +59,7 @@ def create_modules(module_defs):
modules.add_module('shortcut_%d' % i, EmptyLayer()) modules.add_module('shortcut_%d' % i, EmptyLayer())
elif module_def['type'] == 'yolo': elif module_def['type'] == 'yolo':
yolo_index += 1
anchor_idxs = [int(x) for x in module_def['mask'].split(',')] anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
# Extract anchors # Extract anchors
anchors = [float(x) for x in module_def['anchors'].split(',')] anchors = [float(x) for x in module_def['anchors'].split(',')]
@ -66,8 +68,7 @@ def create_modules(module_defs):
nc = int(module_def['classes']) # number of classes nc = int(module_def['classes']) # number of classes
img_size = hyperparams['height'] img_size = hyperparams['height']
# Define detection layer # Define detection layer
yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg']) modules.add_module('yolo_%d' % i, YOLOLayer(anchors, nc, img_size, yolo_index))
modules.add_module('yolo_%d' % i, yolo_layer)
# Register module list and number of output filters # Register module list and number of output filters
module_list.append(modules) module_list.append(modules)
@ -99,7 +100,7 @@ class Upsample(nn.Module):
class YOLOLayer(nn.Module): class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, cfg): def __init__(self, anchors, nc, img_size, yolo_index):
super(YOLOLayer, self).__init__() super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors) self.anchors = torch.Tensor(anchors)
@ -109,7 +110,7 @@ class YOLOLayer(nn.Module):
self.ny = 0 # initialize number of y gridpoints self.ny = 0 # initialize number of y gridpoints
if ONNX_EXPORT: # grids must be computed in __init__ if ONNX_EXPORT: # grids must be computed in __init__
stride = [32, 16, 8][yolo_layer] # stride of this layer stride = [32, 16, 8][yolo_index] # stride of this layer
nx = int(img_size[1] / stride) # number x grid points nx = int(img_size[1] / stride) # number x grid points
ny = int(img_size[0] / stride) # number y grid points ny = int(img_size[0] / stride) # number y grid points
create_grids(self, max(img_size), (nx, ny)) create_grids(self, max(img_size), (nx, ny))