This commit is contained in:
Glenn Jocher 2020-04-16 16:12:23 -07:00
parent 9ea856242f
commit bf1061c146
1 changed files with 8 additions and 8 deletions

View File

@ -42,7 +42,7 @@ class WeightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
self.weight = weight # apply weights boolean
self.n = len(layers) + 1 # number of layers
if weight:
self.w = torch.nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights
self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True) # layer weights
def forward(self, x, outputs):
# Weights
@ -83,13 +83,13 @@ class MixConv2d(nn.Module): # MixConv: Mixed Depthwise Convolutional Kernels ht
a[0] = 1
ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int) # solve for equal weight indices, ax = b
self.m = nn.ModuleList([torch.nn.Conv2d(in_channels=in_ch,
out_channels=ch[g],
kernel_size=k[g],
stride=stride,
padding=k[g] // 2, # 'same' pad
dilation=dilation,
bias=bias) for g in range(groups)])
self.m = nn.ModuleList([nn.Conv2d(in_channels=in_ch,
out_channels=ch[g],
kernel_size=k[g],
stride=stride,
padding=k[g] // 2, # 'same' pad
dilation=dilation,
bias=bias) for g in range(groups)])
def forward(self, x):
return torch.cat([m(x) for m in self.m], 1)