import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.utils import *


class WeightedFeatureFusion(nn.Module):  # weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
    def __init__(self, layers, weight=False):
        super(WeightedFeatureFusion, self).__init__()
        self.layers = layers  # layer indices
        self.weight = weight  # apply weights boolean
        self.n = len(layers) + 1  # number of layers
        if weight:
            self.w = nn.Parameter(torch.zeros(self.n), requires_grad=True)  # layer weights

    def forward(self, x, outputs):
        # Weights
        if self.weight:
            w = torch.sigmoid(self.w) * (2 / self.n)  # sigmoid weights (0-1)
            x = x * w[0]

        # Fusion
        nx = x.shape[1]  # input channels
        for i in range(self.n - 1):
            a = outputs[self.layers[i]] * w[i + 1] if self.weight else outputs[self.layers[i]]  # feature to add
            na = a.shape[1]  # feature channels

            # Adjust channels
            if nx == na:  # same shape
                x = x + a
            elif nx > na:  # slice input
                x[:, :na] = x[:, :na] + a  # or zero-pad a by dc = nx - na channels: a = nn.ZeroPad2d((0, 0, 0, 0, 0, dc))(a); x = x + a
            else:  # slice feature
                x = x + a[:, :nx]
        return x


class SwishImplementation(torch.autograd.Function):
    # Recomputes sigmoid in backward instead of storing the activation, saving memory
    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        return i * torch.sigmoid(i)

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]  # saved_tensors replaces the deprecated saved_variables
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)  # out-of-place; in-place mul_ can break autograd


class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
    def forward(self, x):
        return x * F.softplus(x).tanh()  # x * tanh(softplus(x))
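

# --- Minimal usage sketch (illustrative only, not part of the original module) ---
# The shapes, layer indices, and `outputs` list below are assumptions chosen to
# exercise WeightedFeatureFusion and the activations; in the repo this wiring comes
# from the Darknet model definition rather than code like this.
if __name__ == '__main__':
    x = torch.randn(1, 64, 32, 32)            # current layer output (64 channels)
    outputs = [torch.randn(1, 32, 32, 32)]    # cached output of an earlier layer (32 channels)

    fuse = WeightedFeatureFusion(layers=[0], weight=True)
    y = fuse(x, outputs)                      # nx > na, so the feature is added into x[:, :32]
    print(y.shape)                            # torch.Size([1, 64, 32, 32])

    z = torch.randn(2, 8)
    print(torch.allclose(MemoryEfficientSwish()(z), Swish()(z)))  # True: same math, less activation memory
    print(Mish()(z).shape)                    # torch.Size([2, 8])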