diff --git a/utils/layers.py b/utils/layers.py
index 81d3408c..35c13c9f 100644
--- a/utils/layers.py
+++ b/utils/layers.py
@@ -98,14 +98,29 @@ class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels ht
 # Activation functions below -------------------------------------------------------------------------------------------
 class SwishImplementation(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, i):
-        ctx.save_for_backward(i)
-        return i * torch.sigmoid(i)
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return x * torch.sigmoid(x)
 
     @staticmethod
     def backward(ctx, grad_output):
-        sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
-        return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
+        x = ctx.saved_tensors[0]
+        sx = torch.sigmoid(x)  # sigmoid(x)
+        return grad_output * (sx * (1 + x * (1 - sx)))
+
+
+class MishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        sx = torch.sigmoid(x)
+        fx = F.softplus(x).tanh()
+        return grad_output * (fx + x * sx * (1 - fx * fx))
 
 
 class MemoryEfficientSwish(nn.Module):
@@ -113,9 +128,14 @@ class MemoryEfficientSwish(nn.Module):
         return SwishImplementation.apply(x)
 
 
+class MemoryEfficientMish(nn.Module):
+    def forward(self, x):
+        return MishImplementation.apply(x)
+
+
 class Swish(nn.Module):
     def forward(self, x):
-        return x.mul(torch.sigmoid(x))
+        return x * torch.sigmoid(x)
 
 
 class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
@@ -125,4 +145,4 @@ class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
 
 class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
     def forward(self, x):
-        return x.mul(F.softplus(x).tanh())
+        return x * F.softplus(x).tanh()
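
Note: both hand-written backward passes follow from the product rule. For Swish, d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))). For Mish, with fx = tanh(softplus(x)) and softplus'(x) = sigmoid(x), d/dx [x * fx] = fx + x * sigmoid(x) * (1 - fx * fx), which is exactly what backward() returns. A minimal sketch to check the analytic gradients against finite differences, assuming the patched utils.layers is importable (torch.autograd.gradcheck requires double-precision inputs with requires_grad=True):

import torch
from utils.layers import SwishImplementation, MishImplementation  # assumes patched module on path

# gradcheck compares the hand-written backward passes with numerical gradients
x = torch.randn(8, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(SwishImplementation.apply, (x,)))  # True
print(torch.autograd.gradcheck(MishImplementation.apply, (x,)))  # True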
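
Because the autograd.Function variants call ctx.save_for_backward(x), only the input tensor is kept for the backward pass, rather than the intermediate sigmoid/softplus/tanh results the eager-mode Swish/Mish expressions would keep alive; that appears to be the motivation for the "MemoryEfficient" naming. Usage is a drop-in nn.Module swap; a hypothetical block, not from the patch:

import torch
import torch.nn as nn
from utils.layers import MemoryEfficientMish  # assumes patched module on path

# hypothetical conv block: the activation replaces Mish() / nn.ReLU() directly
block = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    MemoryEfficientMish(),
)
y = block(torch.randn(1, 3, 32, 32))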