This commit is contained in:
Glenn Jocher 2020-02-22 12:54:09 -08:00
parent b70e39ab9b
commit 2d9bc62526
1 changed files with 6 additions and 6 deletions

View File

@ -135,22 +135,22 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
# Fusion # Fusion
nc = x.shape[1] # input channels nc = x.shape[1] # input channels
for i in range(self.n - 1): for i in range(self.n - 1):
a = outputs[self.layers[i]] # feature to add a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add
ac = a.shape[1] # feature channels ac = a.shape[1] # feature channels
dc = nc - ac # delta channels dc = nc - ac # delta channels
# Adjust channels # Adjust channels
if dc > 0: # slice input if dc > 0: # slice input
# a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a) # a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a)
x[:, :ac] = x[:, :ac] + (a * w[i + 1] if self.weight else a) x[:, :ac] = x[:, :ac] + a
elif dc < 0: # slice feature elif dc < 0: # slice feature
if self.n == 2: if self.n == 2:
return x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc]) return x + a[:, :nc]
x = x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc]) x = x + a[:, :nc]
else: # same shape else: # same shape
if self.n == 2: if self.n == 2:
return x + (a * w[i + 1] if self.weight else a) return x + a
x = x + (a * w[i + 1] if self.weight else a) x = x + a
return x return x