new layers.py file

parent 4ac60018f6
commit 27c7334e81

models.py (67 changed lines)
@@ -1,8 +1,6 @@
-import torch.nn.functional as F
-
 from utils.google_utils import *
+from utils.layers import *
 from utils.parse_config import *
-from utils.utils import *
 
 
 ONNX_EXPORT = False
@@ -70,7 +68,7 @@ def create_modules(module_defs, img_size):
             layers = mdef['from']
             filters = output_filters[-1]
             routs.extend([i + l if l < 0 else l for l in layers])
-            modules = weightedFeatureFusion(layers=layers, weight='weights_type' in mdef)
+            modules = WeightedFeatureFusion(layers=layers, weight='weights_type' in mdef)
 
         elif mdef['type'] == 'reorg3d':  # yolov3-spp-pan-scale
             pass
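For context, a minimal sketch (not part of the commit) of how a parsed 'shortcut' block is turned into the fusion module. It assumes the repository root is on the Python path; the mdef dict and the 'per_output' value are made-up placeholders, since only the presence of the 'weights_type' key matters here.

    from utils.layers import WeightedFeatureFusion

    # Hypothetical parsed cfg entry: add the output of the layer 3 steps back.
    mdef = {'type': 'shortcut', 'from': [-3], 'weights_type': 'per_output'}

    layers = mdef['from']                                          # e.g. [-3]
    fusion = WeightedFeatureFusion(layers=layers,
                                   weight='weights_type' in mdef)  # weight=True -> learnable fusion weights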
@@ -111,65 +109,6 @@ def create_modules(module_defs, img_size):
     return module_list, routs_binary
 
 
-class weightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
-    def __init__(self, layers, weight=False):
-        super(weightedFeatureFusion, self).__init__()
-        self.layers = layers  # layer indices
-        self.weight = weight  # apply weights boolean
-        self.n = len(layers) + 1  # number of layers
-        if weight:
-            self.w = torch.nn.Parameter(torch.zeros(self.n))  # layer weights
-
-    def forward(self, x, outputs):
-        # Weights
-        if self.weight:
-            w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1)
-            x = x * w[0]
-
-        # Fusion
-        nx = x.shape[1]  # input channels
-        for i in range(self.n - 1):
-            a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]]  # feature to add
-            na = a.shape[1]  # feature channels
-
-            # Adjust channels
-            if nx == na:  # same shape
-                x = x + a
-            elif nx > na:  # slice input
-                x[:, :na] = x[:, :na] + a  # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
-            else:  # slice feature
-                x = x + a[:, :nx]
-
-        return x
-
-
-class SwishImplementation(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, i):
-        ctx.save_for_backward(i)
-        return i * torch.sigmoid(i)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
-        return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
-
-
-class MemoryEfficientSwish(nn.Module):
-    def forward(self, x):
-        return SwishImplementation.apply(x)
-
-
-class Swish(nn.Module):
-    def forward(self, x):
-        return x.mul_(torch.sigmoid(x))
-
-
-class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
-    def forward(self, x):
-        return x.mul_(F.softplus(x).tanh())
-
-
 class YOLOLayer(nn.Module):
     def __init__(self, anchors, nc, img_size, yolo_index, layers):
         super(YOLOLayer, self).__init__()
@@ -277,7 +216,7 @@ class Darknet(nn.Module):
                     l = [i - 1] + module.layers  # layers
                     s = [list(x.shape)] + [list(out[i].shape) for i in module.layers]  # shapes
                     str = ' >> ' + ' + '.join(['layer %g %s' % x for x in zip(l, s)])
-                x = module(x, out)  # weightedFeatureFusion()
+                x = module(x, out)  # WeightedFeatureFusion()
             elif mtype == 'route':  # concat
                 layers = mdef['layers']
                 if verbose:
utils/layers.py (new file)

@@ -0,0 +1,62 @@
+import torch.nn.functional as F
+
+from utils.utils import *
+
+
+class WeightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
+    def __init__(self, layers, weight=False):
+        super(WeightedFeatureFusion, self).__init__()
+        self.layers = layers  # layer indices
+        self.weight = weight  # apply weights boolean
+        self.n = len(layers) + 1  # number of layers
+        if weight:
+            self.w = torch.nn.Parameter(torch.zeros(self.n), requires_grad=True)  # layer weights
+
+    def forward(self, x, outputs):
+        # Weights
+        if self.weight:
+            w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1)
+            x = x * w[0]
+
+        # Fusion
+        nx = x.shape[1]  # input channels
+        for i in range(self.n - 1):
+            a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]]  # feature to add
+            na = a.shape[1]  # feature channels
+
+            # Adjust channels
+            if nx == na:  # same shape
+                x = x + a
+            elif nx > na:  # slice input
+                x[:, :na] = x[:, :na] + a  # or a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
+            else:  # slice feature
+                x = x + a[:, :nx]
+
+        return x
+
+
+class SwishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, i):
+        ctx.save_for_backward(i)
+        return i * torch.sigmoid(i)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
+        return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
+
+
+class MemoryEfficientSwish(nn.Module):
+    def forward(self, x):
+        return SwishImplementation.apply(x)
+
+
+class Swish(nn.Module):
+    def forward(self, x):
+        return x.mul_(torch.sigmoid(x))
+
+
+class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
+    def forward(self, x):
+        return x.mul_(F.softplus(x).tanh())
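As a rough illustration of what the relocated module computes, a small standalone sketch (assumption: run from the repository root so utils.layers is importable; shapes are arbitrary):

    import torch
    from utils.layers import WeightedFeatureFusion

    fusion = WeightedFeatureFusion(layers=[2], weight=True)  # fuse with the cached output of layer 2

    x = torch.randn(1, 256, 16, 16)                  # current feature map (nx = 256 channels)
    out = [None, None, torch.randn(1, 128, 16, 16)]  # cached layer outputs; layer 2 has na = 128 channels

    y = fusion(x, out)   # nx > na, so the sum is written into the first 128 channels of x
    print(y.shape)       # torch.Size([1, 256, 16, 16])

With weight=True the contributions are scaled by sigmoid weights that start at sigmoid(0) * 2/n = 0.5 each and are learned during training.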
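And a quick sanity-check sketch (same import assumption) that the plain and memory-efficient Swish variants agree on the forward pass; clones are used because Swish multiplies its input in place:

    import torch
    from utils.layers import Swish, MemoryEfficientSwish

    x = torch.randn(4, 8)

    y_plain = Swish()(x.clone())                  # in-place x * sigmoid(x)
    y_memeff = MemoryEfficientSwish()(x.clone())  # autograd.Function that saves only the input for backward

    print(torch.allclose(y_plain, y_memeff))      # True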