updates
This commit is contained in:
parent
00862e47ef
commit
f92ad043bd
30
models.py
30
models.py
|
@ -121,24 +121,34 @@ def create_modules(module_defs, img_size, arc):
|
||||||
class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
class weightedFeatureFusion(nn.Module): # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
|
||||||
def __init__(self, layers, weight=False):
|
def __init__(self, layers, weight=False):
|
||||||
super(weightedFeatureFusion, self).__init__()
|
super(weightedFeatureFusion, self).__init__()
|
||||||
self.n = len(layers) + 1 # number of layers
|
|
||||||
self.layers = layers # layer indices
|
self.layers = layers # layer indices
|
||||||
self.weight = weight # apply weights boolean
|
self.weight = weight # apply weights boolean
|
||||||
|
self.n = len(layers) + 1 # number of layers
|
||||||
if weight:
|
if weight:
|
||||||
self.w = torch.nn.Parameter(torch.zeros(self.n)) # layer weights
|
self.w = torch.nn.Parameter(torch.zeros(self.n)) # layer weights
|
||||||
|
|
||||||
def forward(self, x, outputs):
|
def forward(self, x, outputs):
|
||||||
|
# Weights
|
||||||
if self.weight:
|
if self.weight:
|
||||||
w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1)
|
w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1)
|
||||||
if self.n == 2:
|
x = x * w[0]
|
||||||
return x * w[0] + outputs[self.layers[0]] * w[1]
|
|
||||||
elif self.n == 3:
|
# Fusion
|
||||||
return x * w[0] + outputs[self.layers[0]] * w[1] + outputs[self.layers[1]] * w[2]
|
nc = x.shape[1] # number of channels
|
||||||
else:
|
for i in range(self.n - 1):
|
||||||
if self.n == 2:
|
a = outputs[self.layers[i]] # feature to add
|
||||||
return x + outputs[self.layers[0]]
|
dc = nc - a.shape[1] # delta channels
|
||||||
elif self.n == 3:
|
|
||||||
return x + outputs[self.layers[0]] + outputs[self.layers[1]]
|
# Adjust channels
|
||||||
|
if dc > 0: # pad
|
||||||
|
pad = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))
|
||||||
|
a = pad(a)
|
||||||
|
elif dc < 0: # slice
|
||||||
|
a = a[:, :nc]
|
||||||
|
|
||||||
|
# Sum
|
||||||
|
x = x + a * w[i + 1] if self.weight else x + a
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SwishImplementation(torch.autograd.Function):
|
class SwishImplementation(torch.autograd.Function):
|
||||||
|
|
Loading…
Reference in New Issue