2020-04-02 19:22:15 +00:00
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from utils.utils import *
|
|
|
|
|
|
|
|
|
2020-04-14 08:20:57 +00:00
|
|
|
def make_divisible(v, divisor):
    """Round ``v`` up to the nearest multiple of ``divisor``.

    Used to keep layer channel counts aligned to a fixed multiple. Adapted from
    the TensorFlow MobileNet helper
    (https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py);
    this version always rounds up and applies no minimum-value clamp.

    Args:
        v: value (e.g. a channel count) to round.
        divisor: the multiple to round up to.

    Returns:
        ``math.ceil(v / divisor) * divisor``.
    """
    multiples = math.ceil(v / divisor)  # whole divisors needed to cover v
    return multiples * divisor
|
|
|
|
|
|
|
|
|
|
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    Intended for use after ``nn.AdaptiveAvgPool2d(1)`` to drop the trailing
    1x1 spatial dimensions. Uses ``Tensor.view``, so the input must be
    view-compatible (e.g. contiguous).
    """

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)
|
|
|
|
|
|
|
|
|
|
|
|
class Concat(nn.Module):
    """Concatenate a sequence of tensors along a fixed dimension.

    Args:
        dimension: dimension passed to ``torch.cat`` (default 1).
    """

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension  # concatenation dimension

    def forward(self, x):
        # x: a sequence of tensors accepted by torch.cat
        return torch.cat(x, dim=self.d)
|
|
|
|
|
|
|
|
|
2020-04-05 21:47:41 +00:00
|
|
|
class FeatureConcat(nn.Module):
    """Select (and, for several indices, channel-concatenate) cached layer outputs.

    Args:
        layers: indices into the ``outputs`` list passed to ``forward``.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = layers  # layer indices into `outputs`
        self.multiple = len(layers) > 1  # True -> concatenate on dim 1, False -> pass one output through

    def forward(self, x, outputs):
        # `x` is not used; the result comes only from the cached `outputs`.
        if not self.multiple:
            return outputs[self.layers[0]]
        return torch.cat([outputs[idx] for idx in self.layers], 1)
|
|
|
|
|
|
|
|
|
2020-04-02 19:22:15 +00:00
|
|
|
class WeightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    """Sum the current activation with the cached outputs of earlier layers.

    Args:
        layers: indices into the ``outputs`` list passed to ``forward``; each
            indexed tensor is added to the input.
        weight: if True, learn one scalar weight per summand (the input plus
            each fused layer), squashed with ``sigmoid`` and scaled by ``2 / n``.

    Channel mismatches (dim 1) are handled by slicing: a feature with fewer
    channels than the input is added to the input's leading channels; a
    feature with more channels is truncated to the input's channel count.
    """

    def __init__(self, layers, weight=False):
        super(WeightedFeatureFusion, self).__init__()
        self.layers = layers  # layer indices
        self.weight = weight  # apply weights boolean
        self.n = len(layers) + 1  # number of summands (input + fused layers)
        if weight:
            # zeros -> sigmoid(0) * 2 / n = 1 / n per summand, i.e. an average at init
            self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True)  # layer weights

    def forward(self, x, outputs):
        """Fuse ``x`` with ``outputs[i]`` for every ``i`` in ``self.layers``.

        Args:
            x: current activation, channels on dim 1.
            outputs: list of cached layer outputs indexed by ``self.layers``.

        Returns:
            A new tensor shaped like ``x``. Neither ``x`` nor any tensor in
            ``outputs`` is modified.
        """
        # Weights
        w = None
        if self.weight:
            w = torch.sigmoid(self.w) * (2 / self.n)  # each weight in (0, 2/n)
            x = x * w[0]

        # Fusion
        nx = x.shape[1]  # input channels
        for i, layer in enumerate(self.layers):
            a = outputs[layer]  # feature to add
            if w is not None:
                a = a * w[i + 1]
            na = a.shape[1]  # feature channels

            # Adjust channels
            if nx == na:  # same shape
                x = x + a
            elif nx > na:  # add the feature to the input's leading channels
                # Out-of-place on purpose: writing into x[:, :na] would mutate the
                # caller's tensor (x may still alias the input / an entry of
                # `outputs`) and fails autograd when x is a leaf requiring grad.
                x = torch.cat((x[:, :na] + a, x[:, na:]), 1)
            else:  # slice feature
                x = x + a[:, :nx]

        return x
|
|
|
|
|
|
|
|
|
2020-04-04 03:03:44 +00:00
|
|
|
class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels https://arxiv.org/abs/1907.09595
    """Parallel convolutions with different kernel sizes, concatenated on channels.

    Every branch receives all ``in_ch`` input channels; the ``out_ch`` output
    channels are split across the kernel sizes in ``k``.

    Args:
        in_ch: input channels.
        out_ch: total output channels (sum over all branches).
        k: kernel sizes, one branch per entry. Odd sizes give exact 'same'
            padding at stride 1.
        stride, dilation, bias: forwarded to every ``nn.Conv2d`` branch.
        method: 'equal_ch' splits ``out_ch`` evenly across branches; any other
            value ('equal_params', the default) sizes branches so that
            channels * k**2 is roughly equal, i.e. each branch has a similar
            number of weights.
    """

    def __init__(self, in_ch, out_ch, k=(3, 5, 7), stride=1, dilation=1, bias=True, method='equal_params'):
        super(MixConv2d, self).__init__()

        groups = len(k)
        if method == 'equal_ch':  # equal channels per group
            i = torch.linspace(0, groups - 1E-6, out_ch).floor()  # out_ch indices
            ch = [int((i == g).sum()) for g in range(groups)]  # plain ints, not 0-d tensors
        else:  # 'equal_params': equal parameter count per group
            # Rows: sum(ch) = out_ch, then ch[g] * k[g]**2 - ch[g+1] * k[g+1]**2 = 0 (cyclic)
            b = [out_ch] + [0] * groups
            a = np.eye(groups + 1, groups, k=-1)
            a -= np.roll(a, 1, axis=1)
            a *= np.array(k) ** 2
            a[0] = 1
            ch = np.linalg.lstsq(a, b, rcond=None)[0].round().astype(int)  # solve for equal weight indices, ax = b
            # Rounding can drift from out_ch (e.g. out_ch=10 -> 6 + 2 + 1 = 9); give the remainder to branch 0
            ch[0] += out_ch - ch.sum()
            ch = [int(c) for c in ch]

        self.m = nn.ModuleList([nn.Conv2d(in_channels=in_ch,
                                          out_channels=ch[g],
                                          kernel_size=k[g],
                                          stride=stride,
                                          # 'same' pad; scaled by dilation so all branches keep the
                                          # same spatial size and can be concatenated
                                          padding=dilation * (k[g] // 2),
                                          dilation=dilation,
                                          bias=bias) for g in range(groups)])

    def forward(self, x):
        return torch.cat([m(x) for m in self.m], 1)
|
|
|
|
|
|
|
|
|
|
|
|
# Activation functions below -------------------------------------------------------------------------------------------
|
2020-04-02 19:22:15 +00:00
|
|
|
class SwishImplementation(torch.autograd.Function):
    """Swish ``i * sigmoid(i)`` with a hand-written backward pass.

    ``forward`` saves only the input tensor; ``backward`` recomputes the
    sigmoid from it.
    """

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        # d/di [i * sigmoid(i)] = sigmoid(i) * (1 + i * (1 - sigmoid(i)))
        i, = ctx.saved_tensors  # `saved_variables` is the deprecated alias
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
|
|
|
|
|
|
|
|
|
|
|
class MemoryEfficientSwish(nn.Module):
    """Swish activation computed through the custom ``SwishImplementation`` autograd Function."""

    def forward(self, x):
        swish_fn = SwishImplementation.apply
        return swish_fn(x)
|
|
|
|
|
|
|
|
|
|
|
|
class Swish(nn.Module):
    """Swish activation ``x * sigmoid(x)``, differentiated by regular autograd."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
|
2020-04-02 19:22:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
    """Mish activation ``x * tanh(softplus(x))``."""

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
|