updates
This commit is contained in:
		
							parent
							
								
									a971b33b74
								
							
						
					
					
						commit
						b022648716
					
				
							
								
								
									
										16
									
								
								models.py
								
								
								
								
							
							
						
						
									
										16
									
								
								models.py
								
								
								
								
							|  | @ -118,19 +118,21 @@ def create_modules(module_defs, img_size, arc): | ||||||
|     return module_list, routs |     return module_list, routs | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class weightedFeatureFusion(nn.Module):  # weighted sum of layers https://arxiv.org/abs/1911.09070 | class weightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 | ||||||
|     def __init__(self, layers): |     def __init__(self, layers): | ||||||
|         super(weightedFeatureFusion, self).__init__() |         super(weightedFeatureFusion, self).__init__() | ||||||
|         self.n = len(layers)  # number of layers |         self.n = len(layers) + 1  # number of layers | ||||||
|         self.layers = layers  # layer indices |         self.layers = layers  # layer indices | ||||||
|         self.w = torch.nn.Parameter(torch.zeros(self.n + 1))  # layer weights |         self.w = torch.nn.Parameter(torch.zeros(self.n))  # layer weights | ||||||
| 
 | 
 | ||||||
|     def forward(self, x, outputs): |     def forward(self, x, outputs): | ||||||
|         w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1) |         w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1) | ||||||
|         x = x * w[0] |         if self.n == 2: | ||||||
|         for i in range(self.n): |             return x * w[0] + outputs[self.layers[0]] * w[1] | ||||||
|             x = x + outputs[self.layers[i]] * w[i + 1] |         elif self.n == 3: | ||||||
|         return x |             return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2] | ||||||
|  |         else: | ||||||
|  |             raise ValueError('weightedFeatureFusion() supports up to 3 layer inputs, %g attempted' % self.n) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SwishImplementation(torch.autograd.Function): | class SwishImplementation(torch.autograd.Function): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue