diff --git a/models.py b/models.py index 6ebfa495..6a3a4ea7 100755 --- a/models.py +++ b/models.py @@ -135,22 +135,22 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http # Fusion nc = x.shape[1] # input channels for i in range(self.n - 1): - a = outputs[self.layers[i]] # feature to add + a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]] # feature to add ac = a.shape[1] # feature channels dc = nc - ac # delta channels # Adjust channels if dc > 0: # slice input # a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a) - x[:, :ac] = x[:, :ac] + (a * w[i + 1] if self.weight else a) + x[:, :ac] = x[:, :ac] + a elif dc < 0: # slice feature if self.n == 2: - return x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc]) - x = x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc]) + return x + a[:, :nc] + x = x + a[:, :nc] else: # same shape if self.n == 2: - return x + (a * w[i + 1] if self.weight else a) - x = x + (a * w[i + 1] if self.weight else a) + return x + a + x = x + a return x