Merge remote-tracking branch 'origin/master'

This commit is contained in:
Glenn Jocher 2019-06-26 11:28:06 +02:00
commit c1bb037cbe
1 changed files with 3 additions and 4 deletions

View File

@ -15,7 +15,7 @@ def create_modules(module_defs):
hyperparams = module_defs.pop(0) hyperparams = module_defs.pop(0)
output_filters = [int(hyperparams['channels'])] output_filters = [int(hyperparams['channels'])]
module_list = nn.ModuleList() module_list = nn.ModuleList()
yolo_layer_count = 0
for i, module_def in enumerate(module_defs): for i, module_def in enumerate(module_defs):
modules = nn.Sequential() modules = nn.Sequential()
@ -66,9 +66,8 @@ def create_modules(module_defs):
nc = int(module_def['classes']) # number of classes nc = int(module_def['classes']) # number of classes
img_size = hyperparams['height'] img_size = hyperparams['height']
# Define detection layer # Define detection layer
yolo_layer = YOLOLayer(anchors, nc, img_size, yolo_layer_count, cfg=hyperparams['cfg']) yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg'])
modules.add_module('yolo_%d' % i, yolo_layer) modules.add_module('yolo_%d' % i, yolo_layer)
yolo_layer_count += 1
# Register module list and number of output filters # Register module list and number of output filters
module_list.append(modules) module_list.append(modules)
@ -100,7 +99,7 @@ class Upsample(nn.Module):
class YOLOLayer(nn.Module): class YOLOLayer(nn.Module):
def __init__(self, anchors, nc, img_size, yolo_layer, cfg): def __init__(self, anchors, nc, img_size, cfg):
super(YOLOLayer, self).__init__() super(YOLOLayer, self).__init__()
self.anchors = torch.Tensor(anchors) self.anchors = torch.Tensor(anchors)