updates
This commit is contained in:
parent
b97b88b659
commit
b70e39ab9b
23
models.py
23
models.py
|
@ -133,19 +133,24 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http
|
||||||
x = x * w[0]
|
x = x * w[0]
|
||||||
|
|
||||||
# Fusion
|
# Fusion
|
||||||
nc = x.shape[1] # number of channels
|
nc = x.shape[1] # input channels
|
||||||
for i in range(self.n - 1):
|
for i in range(self.n - 1):
|
||||||
a = outputs[self.layers[i]] # feature to add
|
a = outputs[self.layers[i]] # feature to add
|
||||||
dc = nc - a.shape[1] # delta channels
|
ac = a.shape[1] # feature channels
|
||||||
|
dc = nc - ac # delta channels
|
||||||
|
|
||||||
# Adjust channels
|
# Adjust channels
|
||||||
if dc > 0: # pad
|
if dc > 0: # slice input
|
||||||
a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a)
|
# a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a)
|
||||||
elif dc < 0: # slice
|
x[:, :ac] = x[:, :ac] + (a * w[i + 1] if self.weight else a)
|
||||||
a = a[:, :nc]
|
elif dc < 0: # slice feature
|
||||||
|
if self.n == 2:
|
||||||
# Sum
|
return x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc])
|
||||||
x = x + a * w[i + 1] if self.weight else x + a
|
x = x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc])
|
||||||
|
else: # same shape
|
||||||
|
if self.n == 2:
|
||||||
|
return x + (a * w[i + 1] if self.weight else a)
|
||||||
|
x = x + (a * w[i + 1] if self.weight else a)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue