cleanup
This commit is contained in:
		
							parent
							
								
									9ea856242f
								
							
						
					
					
						commit
						bf1061c146
					
				|  | @ -42,7 +42,7 @@ class WeightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers http | ||||||
|         self.weight = weight  # apply weights boolean |         self.weight = weight  # apply weights boolean | ||||||
|         self.n = len(layers) + 1  # number of layers |         self.n = len(layers) + 1  # number of layers | ||||||
|         if weight: |         if weight: | ||||||
|             self.w = torch.nn.Parameter(torch.zeros(self.n), requires_grad=True)  # layer weights |             self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True)  # layer weights | ||||||
| 
 | 
 | ||||||
|     def forward(self, x, outputs): |     def forward(self, x, outputs): | ||||||
|         # Weights |         # Weights | ||||||
|  | @ -83,13 +83,13 @@ class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels ht | ||||||
|             a[0] = 1 |             a[0] = 1 | ||||||
|             ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int)  # solve for equal weight indices, ax = b |             ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int)  # solve for equal weight indices, ax = b | ||||||
| 
 | 
 | ||||||
|         self.m = nn.ModuleList([torch.nn.Conv2d(in_channels=in_ch, |         self.m = nn.ModuleList([nn.Conv2d(in_channels=in_ch, | ||||||
|                                                 out_channels=ch[g], |                                           out_channels=ch[g], | ||||||
|                                                 kernel_size=k[g], |                                           kernel_size=k[g], | ||||||
|                                                 stride=stride, |                                           stride=stride, | ||||||
|                                                 padding=k[g] // 2,  # 'same' pad |                                           padding=k[g] // 2,  # 'same' pad | ||||||
|                                                 dilation=dilation, |                                           dilation=dilation, | ||||||
|                                                 bias=bias) for g in range(groups)]) |                                           bias=bias) for g in range(groups)]) | ||||||
| 
 | 
 | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         return torch.cat([m(x) for m in self.m], 1) |         return torch.cat([m(x) for m in self.m], 1) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue