From 1d0a4a3ace94567172e2adbabfba0fa553a07f1a Mon Sep 17 00:00:00 2001 From: glenn-jocher Date: Wed, 3 Jul 2019 14:42:11 +0200 Subject: [PATCH] updates --- models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/models.py b/models.py index 347264fd..6b71989a 100755 --- a/models.py +++ b/models.py @@ -15,6 +15,7 @@ def create_modules(module_defs): hyperparams = module_defs.pop(0) output_filters = [int(hyperparams['channels'])] module_list = nn.ModuleList() + yolo_index = -1 for i, module_def in enumerate(module_defs): modules = nn.Sequential() @@ -58,6 +59,7 @@ def create_modules(module_defs): modules.add_module('shortcut_%d' % i, EmptyLayer()) elif module_def['type'] == 'yolo': + yolo_index += 1 anchor_idxs = [int(x) for x in module_def['mask'].split(',')] # Extract anchors anchors = [float(x) for x in module_def['anchors'].split(',')] @@ -66,8 +68,7 @@ def create_modules(module_defs): nc = int(module_def['classes']) # number of classes img_size = hyperparams['height'] # Define detection layer - yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg']) - modules.add_module('yolo_%d' % i, yolo_layer) + modules.add_module('yolo_%d' % i, YOLOLayer(anchors, nc, img_size, yolo_index)) # Register module list and number of output filters module_list.append(modules) @@ -99,7 +100,7 @@ class Upsample(nn.Module): class YOLOLayer(nn.Module): - def __init__(self, anchors, nc, img_size, cfg): + def __init__(self, anchors, nc, img_size, yolo_index): super(YOLOLayer, self).__init__() self.anchors = torch.Tensor(anchors) @@ -109,7 +110,7 @@ class YOLOLayer(nn.Module): self.ny = 0 # initialize number of y gridpoints if ONNX_EXPORT: # grids must be computed in __init__ - stride = [32, 16, 8][yolo_layer] # stride of this layer + stride = [32, 16, 8][yolo_index] # stride of this layer nx = int(img_size[1] / stride) # number x grid points ny = int(img_size[0] / stride) # number y grid points create_grids(self, max(img_size), (nx, ny))