Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
c1bb037cbe
|
@ -15,7 +15,7 @@ def create_modules(module_defs):
|
|||
hyperparams = module_defs.pop(0)
|
||||
output_filters = [int(hyperparams['channels'])]
|
||||
module_list = nn.ModuleList()
|
||||
yolo_layer_count = 0
|
||||
|
||||
for i, module_def in enumerate(module_defs):
|
||||
modules = nn.Sequential()
|
||||
|
||||
|
@ -66,9 +66,8 @@ def create_modules(module_defs):
|
|||
nc = int(module_def['classes']) # number of classes
|
||||
img_size = hyperparams['height']
|
||||
# Define detection layer
|
||||
yolo_layer = YOLOLayer(anchors, nc, img_size, yolo_layer_count, cfg=hyperparams['cfg'])
|
||||
yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg'])
|
||||
modules.add_module('yolo_%d' % i, yolo_layer)
|
||||
yolo_layer_count += 1
|
||||
|
||||
# Register module list and number of output filters
|
||||
module_list.append(modules)
|
||||
|
@ -100,7 +99,7 @@ class Upsample(nn.Module):
|
|||
|
||||
|
||||
class YOLOLayer(nn.Module):
|
||||
def __init__(self, anchors, nc, img_size, yolo_layer, cfg):
|
||||
def __init__(self, anchors, nc, img_size, cfg):
|
||||
super(YOLOLayer, self).__init__()
|
||||
|
||||
self.anchors = torch.Tensor(anchors)
|
||||
|
|
Loading…
Reference in New Issue