Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
c1bb037cbe
|
@ -15,7 +15,7 @@ def create_modules(module_defs):
|
||||||
hyperparams = module_defs.pop(0)
|
hyperparams = module_defs.pop(0)
|
||||||
output_filters = [int(hyperparams['channels'])]
|
output_filters = [int(hyperparams['channels'])]
|
||||||
module_list = nn.ModuleList()
|
module_list = nn.ModuleList()
|
||||||
yolo_layer_count = 0
|
|
||||||
for i, module_def in enumerate(module_defs):
|
for i, module_def in enumerate(module_defs):
|
||||||
modules = nn.Sequential()
|
modules = nn.Sequential()
|
||||||
|
|
||||||
|
@ -66,9 +66,8 @@ def create_modules(module_defs):
|
||||||
nc = int(module_def['classes']) # number of classes
|
nc = int(module_def['classes']) # number of classes
|
||||||
img_size = hyperparams['height']
|
img_size = hyperparams['height']
|
||||||
# Define detection layer
|
# Define detection layer
|
||||||
yolo_layer = YOLOLayer(anchors, nc, img_size, yolo_layer_count, cfg=hyperparams['cfg'])
|
yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg'])
|
||||||
modules.add_module('yolo_%d' % i, yolo_layer)
|
modules.add_module('yolo_%d' % i, yolo_layer)
|
||||||
yolo_layer_count += 1
|
|
||||||
|
|
||||||
# Register module list and number of output filters
|
# Register module list and number of output filters
|
||||||
module_list.append(modules)
|
module_list.append(modules)
|
||||||
|
@ -100,7 +99,7 @@ class Upsample(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class YOLOLayer(nn.Module):
|
class YOLOLayer(nn.Module):
|
||||||
def __init__(self, anchors, nc, img_size, yolo_layer, cfg):
|
def __init__(self, anchors, nc, img_size, cfg):
|
||||||
super(YOLOLayer, self).__init__()
|
super(YOLOLayer, self).__init__()
|
||||||
|
|
||||||
self.anchors = torch.Tensor(anchors)
|
self.anchors = torch.Tensor(anchors)
|
||||||
|
|
Loading…
Reference in New Issue