feature fusion update
This commit is contained in:
parent
108334db29
commit
ac2aa56e0a
16
models.py
16
models.py
|
@ -127,19 +127,19 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
|
|||
x = x * w[0]
|
||||
|
||||
# Fusion
|
||||
nc = x.shape[1] # input channels
|
||||
nx = x.shape[1] # input channels
|
||||
for i in range(self.n - 1):
|
||||
a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add
|
||||
ac = a.shape[1] # feature channels
|
||||
dc = nc - ac # delta channels
|
||||
na = a.shape[1] # feature channels
|
||||
|
||||
# Adjust channels
|
||||
if dc > 0: # slice input
|
||||
x[:, :ac] = x[:, :ac] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
|
||||
elif dc < 0: # slice feature
|
||||
x = x + a[:, :nc]
|
||||
else: # same shape
|
||||
if nx == na: # same shape
|
||||
x = x + a
|
||||
elif nx > na: # slice input
|
||||
x[:, :na] = x[:, :na] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
|
||||
else: # slice feature
|
||||
x = x + a[:, :nx]
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue