diff --git a/models.py b/models.py index 7a870e0a..6ebfa495 100755 --- a/models.py +++ b/models.py @@ -133,19 +133,24 @@ class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers http x = x * w[0] # Fusion - nc = x.shape[1] # number of channels + nc = x.shape[1] # input channels for i in range(self.n - 1): a = outputs[self.layers[i]] # feature to add - dc = nc - a.shape[1] # delta channels + ac = a.shape[1] # feature channels + dc = nc - ac # delta channels # Adjust channels - if dc > 0: # pad - a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a) - elif dc < 0: # slice - a = a[:, :nc] - - # Sum - x = x + a * w[i + 1] if self.weight else x + a + if dc > 0: # slice input + # a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a) + x[:, :ac] = x[:, :ac] + (a * w[i + 1] if self.weight else a) + elif dc < 0: # slice feature + if self.n == 2: + return x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc]) + x = x + (a[:, :nc] * w[i + 1] if self.weight else a[:, :nc]) + else: # same shape + if self.n == 2: + return x + (a * w[i + 1] if self.weight else a) + x = x + (a * w[i + 1] if self.weight else a) return x