updates
This commit is contained in:
parent
45ce01f859
commit
4fa0a32d05
17
models.py
17
models.py
|
@ -70,6 +70,7 @@ def create_modules(module_defs, img_size, arc):
|
||||||
layers = [int(x) for x in mdef['from'].split(',')]
|
layers = [int(x) for x in mdef['from'].split(',')]
|
||||||
filters = output_filters[layers[0]]
|
filters = output_filters[layers[0]]
|
||||||
routs.extend([i + l if l < 0 else l for l in layers])
|
routs.extend([i + l if l < 0 else l for l in layers])
|
||||||
|
# modules = weightedFeatureFusion(layers=layers)
|
||||||
|
|
||||||
elif mdef['type'] == 'reorg3d': # yolov3-spp-pan-scale
|
elif mdef['type'] == 'reorg3d': # yolov3-spp-pan-scale
|
||||||
# torch.Size([16, 128, 104, 104])
|
# torch.Size([16, 128, 104, 104])
|
||||||
|
@ -117,6 +118,21 @@ def create_modules(module_defs, img_size, arc):
|
||||||
return module_list, routs
|
return module_list, routs
|
||||||
|
|
||||||
|
|
||||||
|
class weightedFeatureFusion(nn.Module): # weighted sum of layers https://arxiv.org/abs/1911.09070
|
||||||
|
def __init__(self, layers):
|
||||||
|
super(weightedFeatureFusion, self).__init__()
|
||||||
|
self.n = len(layers) # number of layers
|
||||||
|
self.layers = layers # layer indices
|
||||||
|
self.w = torch.nn.Parameter(torch.zeros(self.n + 1)) # layer weights
|
||||||
|
|
||||||
|
def forward(self, x, outputs):
|
||||||
|
w = torch.sigmoid(self.w) * (2 / self.n) # sigmoid weights (0-1)
|
||||||
|
x = x * w[0]
|
||||||
|
for i in range(self.n):
|
||||||
|
x = x + outputs[self.layers[i]] * w[i + 1]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SwishImplementation(torch.autograd.Function):
|
class SwishImplementation(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, i):
|
def forward(ctx, i):
|
||||||
|
@ -253,6 +269,7 @@ class Darknet(nn.Module):
|
||||||
x = torch.cat([layer_outputs[i] for i in layers], 1)
|
x = torch.cat([layer_outputs[i] for i in layers], 1)
|
||||||
# print(''), [print(layer_outputs[i].shape) for i in layers], print(x.shape)
|
# print(''), [print(layer_outputs[i].shape) for i in layers], print(x.shape)
|
||||||
elif mtype == 'shortcut': # sum
|
elif mtype == 'shortcut': # sum
|
||||||
|
# x = module(x, layer_outputs) # weightedFeatureFusion()
|
||||||
layers = [int(x) for x in mdef['from'].split(',')]
|
layers = [int(x) for x in mdef['from'].split(',')]
|
||||||
if verbose:
|
if verbose:
|
||||||
print('shortcut/add %s' % ([layer_outputs[i].shape for i in layers]))
|
print('shortcut/add %s' % ([layer_outputs[i].shape for i in layers]))
|
||||||
|
|
Loading…
Reference in New Issue