diff --git a/models.py b/models.py index ee9b98c3..049c6584 100755 --- a/models.py +++ b/models.py @@ -127,19 +127,19 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http x = x * w[0] # Fusion - nc = x.shape[1] # input channels + nx = x.shape[1] # input channels for i in range(self.n - 1): a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add - ac = a.shape[1] # feature channels - dc = nc - ac # delta channels + na = a.shape[1] # feature channels # Adjust channels - if dc > 0: # slice input - x[:, :ac] = x[:, :ac] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a - elif dc < 0: # slice feature - x = x + a[:, :nc] - else: # same shape + if nx == na: # same shape x = x + a + elif nx > na: # slice input + x[:, :na] = x[:, :na] + a # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a + else: # slice feature + x = x + a[:, :nx] + return x