From b022648716f5dbb0549747357928c0d865da2a1b Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 18 Feb 2020 20:13:18 -0800
Subject: [PATCH] updates

---
 models.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/models.py b/models.py
index 2c55c74c..abfbf98a 100755
--- a/models.py
+++ b/models.py
@@ -118,19 +118,21 @@ def create_modules(module_defs, img_size, arc):
     return module_list, routs
 
 
-class weightedFeatureFusion(nn.Module):  # weighted sum of layers https://arxiv.org/abs/1911.09070
+class weightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
     def __init__(self, layers):
         super(weightedFeatureFusion, self).__init__()
-        self.n = len(layers)  # number of layers
+        self.n = len(layers) + 1  # number of layers
         self.layers = layers  # layer indices
-        self.w = torch.nn.Parameter(torch.zeros(self.n + 1))  # layer weights
+        self.w = torch.nn.Parameter(torch.zeros(self.n))  # layer weights
 
     def forward(self, x, outputs):
         w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1)
-        x = x * w[0]
-        for i in range(self.n):
-            x = x + outputs[self.layers[i]] * w[i + 1]
-        return x
+        if self.n == 2:
+            return x * w[0] + outputs[self.layers[0]] * w[1]
+        elif self.n == 3:
+            return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2]
+        else:
+            raise ValueError('weightedFeatureFusion() supports up to 3 layer inputs, %g attempted' % self.n)
 
 
 class SwishImplementation(torch.autograd.Function):
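
Note: below is a minimal, self-contained sketch (not part of the patch) exercising the refactored weightedFeatureFusion module on its own. The tensor shapes, the layer index -3, and the plain list used for `outputs` are assumed here purely for illustration; in models.py the module is driven by the Darknet forward pass.

    import torch
    import torch.nn as nn


    class weightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
        def __init__(self, layers):
            super(weightedFeatureFusion, self).__init__()
            self.n = len(layers) + 1  # number of inputs fused (current layer + indexed layers)
            self.layers = layers  # layer indices
            self.w = torch.nn.Parameter(torch.zeros(self.n))  # learnable layer weights

        def forward(self, x, outputs):
            w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1), normalized by n
            if self.n == 2:
                return x * w[0] + outputs[self.layers[0]] * w[1]
            elif self.n == 3:
                return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2]
            else:
                raise ValueError('weightedFeatureFusion() supports up to 3 layer inputs, %g attempted' % self.n)


    # Hypothetical usage: fuse the current feature map with the cached output of layer -3.
    fusion = weightedFeatureFusion(layers=[-3])
    x = torch.randn(1, 256, 16, 16)                           # current layer output (assumed shape)
    outputs = [torch.randn(1, 256, 16, 16) for _ in range(5)]  # cached per-layer outputs (assumed)
    y = fusion(x, outputs)
    print(y.shape)  # torch.Size([1, 256, 16, 16])

With zero-initialized weights, sigmoid(0) * (2 / n) gives each input an equal weight of 1 / n, so the fusion starts as a plain average and the relative weighting is learned during training.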