cleanup
This commit is contained in:
parent
9ea856242f
commit
bf1061c146
|
@ -42,7 +42,7 @@ class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
|
||||||
self.weight = weight # apply weights boolean
|
self.weight = weight # apply weights boolean
|
||||||
self.n = len(layers) + 1 # number of layers
|
self.n = len(layers) + 1 # number of layers
|
||||||
if weight:
|
if weight:
|
||||||
self.w = torch.nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights
|
self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights
|
||||||
|
|
||||||
def forward(self, x, outputs):
|
def forward(self, x, outputs):
|
||||||
# Weights
|
# Weights
|
||||||
|
@ -83,7 +83,7 @@ class MixConv2d(nn.Module): # MixConv: Mixed Depthwise Convolutional Kernels ht
|
||||||
a[0] = 1
|
a[0] = 1
|
||||||
ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int) # solve for equal weight indices, ax = b
|
ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int) # solve for equal weight indices, ax = b
|
||||||
|
|
||||||
self.m = nn.ModuleList([torch.nn.Conv2d(in_channels=in_ch,
|
self.m = nn.ModuleList([nn.Conv2d(in_channels=in_ch,
|
||||||
out_channels=ch[g],
|
out_channels=ch[g],
|
||||||
kernel_size=k[g],
|
kernel_size=k[g],
|
||||||
stride=stride,
|
stride=stride,
|
||||||
|
|
Loading…
Reference in New Issue