feature fusion update

This commit is contained in:
Glenn Jocher 2020-03-30 17:53:17 -07:00
parent 108334db29
commit ac2aa56e0a
1 changed files with 8 additions and 8 deletions

View File

@ -127,19 +127,19 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
x = x * w[0] x = x * w[0]
# Fusion # Fusion
nc = x.shape[1] # input channels nx = x.shape[1] # input channels
for i in range(self.n - 1): for i in range(self.n - 1):
a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add
ac = a.shape[1] # feature channels na = a.shape[1] # feature channels
dc = nc - ac # delta channels
# Adjust channels # Adjust channels
if dc > 0: # slice input if nx == na: # same shape
x[:, :ac] = x[:, :ac] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
elif dc < 0: # slice feature
x = x + a[:, :nc]
else: # same shape
x = x + a x = x + a
elif nx > na: # slice input
x[:, :na] = x[:, :na] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
else: # slice feature
x = x + a[:, :nx]
return x return x