This commit is contained in:
Glenn Jocher 2020-04-16 16:12:23 -07:00
parent 9ea856242f
commit bf1061c146
1 changed files with 8 additions and 8 deletions

View File

@ -42,7 +42,7 @@ class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
self.weight = weight # apply weights boolean self.weight = weight # apply weights boolean
self.n = len(layers) + 1 # number of layers self.n = len(layers) + 1 # number of layers
if weight: if weight:
self.w = torch.nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights
def forward(self, x, outputs): def forward(self, x, outputs):
# Weights # Weights
@ -83,7 +83,7 @@ class MixConv2d(nn.Module): # MixConv: Mixed Depthwise Convolutional Kernels ht
a[0] = 1 a[0] = 1
ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int) # solve for equal weight indices, ax = b ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int) # solve for equal weight indices, ax = b
self.m = nn.ModuleList([torch.nn.Conv2d(in_channels=in_ch, self.m = nn.ModuleList([nn.Conv2d(in_channels=in_ch,
out_channels=ch[g], out_channels=ch[g],
kernel_size=k[g], kernel_size=k[g],
stride=stride, stride=stride,