updates
This commit is contained in:
		
							parent
							
								
									a8cf64af31
								
							
						
					
					
						commit
						1d0a4a3ace
					
				| 
						 | 
				
			
			@ -15,6 +15,7 @@ def create_modules(module_defs):
 | 
			
		|||
    hyperparams = module_defs.pop(0)
 | 
			
		||||
    output_filters = [int(hyperparams['channels'])]
 | 
			
		||||
    module_list = nn.ModuleList()
 | 
			
		||||
    yolo_index = -1
 | 
			
		||||
 | 
			
		||||
    for i, module_def in enumerate(module_defs):
 | 
			
		||||
        modules = nn.Sequential()
 | 
			
		||||
| 
						 | 
				
			
			@ -58,6 +59,7 @@ def create_modules(module_defs):
 | 
			
		|||
            modules.add_module('shortcut_%d' % i, EmptyLayer())
 | 
			
		||||
 | 
			
		||||
        elif module_def['type'] == 'yolo':
 | 
			
		||||
            yolo_index += 1
 | 
			
		||||
            anchor_idxs = [int(x) for x in module_def['mask'].split(',')]
 | 
			
		||||
            # Extract anchors
 | 
			
		||||
            anchors = [float(x) for x in module_def['anchors'].split(',')]
 | 
			
		||||
| 
						 | 
				
			
			@ -66,8 +68,7 @@ def create_modules(module_defs):
 | 
			
		|||
            nc = int(module_def['classes'])  # number of classes
 | 
			
		||||
            img_size = hyperparams['height']
 | 
			
		||||
            # Define detection layer
 | 
			
		||||
            yolo_layer = YOLOLayer(anchors, nc, img_size, cfg=hyperparams['cfg'])
 | 
			
		||||
            modules.add_module('yolo_%d' % i, yolo_layer)
 | 
			
		||||
            modules.add_module('yolo_%d' % i, YOLOLayer(anchors, nc, img_size, yolo_index))
 | 
			
		||||
 | 
			
		||||
        # Register module list and number of output filters
 | 
			
		||||
        module_list.append(modules)
 | 
			
		||||
| 
						 | 
				
			
			@ -99,7 +100,7 @@ class Upsample(nn.Module):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class YOLOLayer(nn.Module):
 | 
			
		||||
    def __init__(self, anchors, nc, img_size, cfg):
 | 
			
		||||
    def __init__(self, anchors, nc, img_size, yolo_index):
 | 
			
		||||
        super(YOLOLayer, self).__init__()
 | 
			
		||||
 | 
			
		||||
        self.anchors = torch.Tensor(anchors)
 | 
			
		||||
| 
						 | 
				
			
			@ -109,7 +110,7 @@ class YOLOLayer(nn.Module):
 | 
			
		|||
        self.ny = 0  # initialize number of y gridpoints
 | 
			
		||||
 | 
			
		||||
        if ONNX_EXPORT:  # grids must be computed in __init__
 | 
			
		||||
            stride = [32, 16, 8][yolo_layer]  # stride of this layer
 | 
			
		||||
            stride = [32, 16, 8][yolo_index]  # stride of this layer
 | 
			
		||||
            nx = int(img_size[1] / stride)  # number x grid points
 | 
			
		||||
            ny = int(img_size[0] / stride)  # number y grid points
 | 
			
		||||
            create_grids(self, max(img_size), (nx, ny))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue