This commit is contained in:
Glenn Jocher 2020-02-19 16:05:57 -08:00
parent 00862e47ef
commit f92ad043bd
1 changed files with 20 additions and 10 deletions

View File

@ -121,24 +121,34 @@ def create_modules(module_defs, img_size, arc):
class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
def __init__(self, layers, weight=False): def __init__(self, layers, weight=False):
super(weightedFeatureFusion, self).__init__() super(weightedFeatureFusion, self).__init__()
self.n = len(layers) + 1 # number of layers
self.layers = layers # layer indices self.layers = layers # layer indices
self.weight = weight # apply weights boolean self.weight = weight # apply weights boolean
self.n = len(layers) + 1 # number of layers
if weight: if weight:
self.w = torch.nn.Parameter(torch.zeros(self.n)) # layer weights self.w = torch.nn.Parameter(torch.zeros(self.n)) # layer weights
def forward(self, x, outputs): def forward(self, x, outputs):
# Weights
if self.weight: if self.weight:
w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1) w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1)
if self.n == 2: x = x * w[0]
return x * w[0] + outputs[self.layers[0]] * w[1]
elif self.n == 3: # Fusion
return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2] nc = x.shape[1] # number of channels
else: for i in range(self.n - 1):
if self.n == 2: a = outputs[self.layers[i]] # feature to add
return x + outputs[self.layers[0]] dc = nc - a.shape[1] # delta channels
elif self.n == 3:
return x + outputs[self.layers[0]] + outputs[self.layers[1]] # Adjust channels
if dc > 0: # pad
pad = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))
a = pad(a)
elif dc < 0: # slice
a = a[:, :nc]
# Sum
x = x + a * w[i + 1] if self.weight else x + a
return x
class SwishImplementation(torch.autograd.Function): class SwishImplementation(torch.autograd.Function):