This commit is contained in:
Glenn Jocher 2020-02-18 20:13:18 -08:00
parent a971b33b74
commit b022648716
1 changed files with 9 additions and 7 deletions

View File

@ -118,19 +118,21 @@ def create_modules(module_defs, img_size, arc):
return module_list, routs return module_list, routs
class weightedFeatureFusion(nn.Module): # weighted sum of layers https://arxiv.org/abs/1911.09070 class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, layers): def __init__(self, layers):
super(weightedFeatureFusion, self).__init__() super(weightedFeatureFusion, self).__init__()
self.n = len(layers) # number of layers self.n = len(layers) + 1 # number of layers
self.layers = layers # layer indices self.layers = layers # layer indices
self.w = torch.nn.Parameter(torch.zeros(self.n + 1)) # layer weights self.w = torch.nn.Parameter(torch.zeros(self.n)) # layer weights
def forward(self, x, outputs): def forward(self, x, outputs):
w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1) w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1)
x = x * w[0] if self.n == 2:
for i in range(self.n): return x * w[0] + outputs[self.layers[0]] * w[1]
x = x + outputs[self.layers[i]] * w[i + 1] elif self.n == 3:
return x return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2]
else:
raise ValueError('weightedFeatureFusion() supports up to 3 layer inputs, %g attempted' % self.n)
class SwishImplementation(torch.autograd.Function): class SwishImplementation(torch.autograd.Function):