diff --git a/models.py b/models.py index 3e5947f4..4de0ad97 100755 --- a/models.py +++ b/models.py @@ -121,24 +121,34 @@ def create_modules(module_defs, img_size, arc): class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, layers, weight=False): super(weightedFeatureFusion, self).__init__() - self.n = len(layers) + 1 # number of layers self.layers = layers # layer indices self.weight = weight # apply weights boolean + self.n = len(layers) + 1 # number of layers if weight: self.w = torch.nn.Parameter(torch.zeros(self.n)) # layer weights def forward(self, x, outputs): + # Weights if self.weight: w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1) - if self.n == 2: - return x * w[0] + outputs[self.layers[0]] * w[1] - elif self.n == 3: - return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2] - else: - if self.n == 2: - return x + outputs[self.layers[0]] - elif self.n == 3: - return x + outputs[self.layers[0]] + outputs[self.layers[1]] + x = x * w[0] + + # Fusion + nc = x.shape[1] # number of channels + for i in range(self.n - 1): + a = outputs[self.layers[i]] # feature to add + dc = nc - a.shape[1] # delta channels + + # Adjust channels + if dc > 0: # pad + pad = nn.ZeroPad2d((0, 0, 0, 0, 0, dc)) + a = pad(a) + elif dc < 0: # slice + a = a[:, :nc] + + # Sum + x = x + a * w[i + 1] if self.weight else x + a + return x class SwishImplementation(torch.autograd.Function):