MemoryEfficientMish()

Glenn Jocher 2020-04-27 13:51:21 -07:00
parent 3aa347a321
commit 2518868508
1 changed file with 27 additions and 7 deletions


@@ -98,14 +98,29 @@ class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels ht
 # Activation functions below -------------------------------------------------------------------------------------------
 class SwishImplementation(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, i):
-        ctx.save_for_backward(i)
-        return i * torch.sigmoid(i)
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return x * torch.sigmoid(x)

     @staticmethod
     def backward(ctx, grad_output):
-        sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
-        return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
+        x = ctx.saved_tensors[0]
+        sx = torch.sigmoid(x)  # recompute sigmoid(x) rather than store it
+        return grad_output * (sx * (1 + x * (1 - sx)))
+
+
+class MishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        sx = torch.sigmoid(x)
+        fx = F.softplus(x).tanh()
+        return grad_output * (fx + x * sx * (1 - fx * fx))


 class MemoryEfficientSwish(nn.Module):
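The hand-written backward passes above apply the chain rule directly. For Swish, f(x) = x * sigmoid(x), so f'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))), which is the sx * (1 + x * (1 - sx)) expression. For Mish, f(x) = x * tanh(softplus(x)); since d/dx softplus(x) = sigmoid(x), f'(x) = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh^2(softplus(x))), i.e. fx + x * sx * (1 - fx * fx). A minimal sketch (not part of the commit) that checks both hand-written gradients against finite differences, assuming SwishImplementation and MishImplementation from the hunk above are in scope:

import torch

# gradcheck needs float64 inputs to meet its numerical tolerances
x = torch.randn(8, dtype=torch.float64, requires_grad=True)

# raises if the analytic backward disagrees with numerical gradients
assert torch.autograd.gradcheck(SwishImplementation.apply, (x,))
assert torch.autograd.gradcheck(MishImplementation.apply, (x,))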
@@ -113,9 +128,14 @@ class MemoryEfficientSwish(nn.Module):
         return SwishImplementation.apply(x)


+class MemoryEfficientMish(nn.Module):
+    def forward(self, x):
+        return MishImplementation.apply(x)
+
+
 class Swish(nn.Module):
     def forward(self, x):
-        return x.mul(torch.sigmoid(x))
+        return x * torch.sigmoid(x)


 class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
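The Memory-Efficient wrappers trade compute for memory: the custom autograd Function saves only the input x and recomputes sigmoid/softplus during backward, whereas the plain Swish and Mish modules leave autograd to keep every intermediate tensor (sigmoid, softplus, tanh outputs) alive until backward. A hedged usage sketch (the toy model below is made up for illustration; any module with a swappable activation works the same way):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    MemoryEfficientMish(),  # drop-in replacement for Mish()
    nn.Conv2d(64, 64, 3, padding=1),
    MemoryEfficientMish(),
)

y = model(torch.randn(16, 3, 256, 256))
y.sum().backward()  # sigmoid/softplus are recomputed here, not read from saved buffers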
@@ -125,4 +145,4 @@ class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
 class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
     def forward(self, x):
-        return x.mul(F.softplus(x).tanh())
+        return x * F.softplus(x).tanh()
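Mish and MemoryEfficientMish compute the same function and should agree in both outputs and gradients; only what is stored between forward and backward differs. A quick consistency check under that assumption, with the classes above in scope:

import torch

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
assert torch.allclose(Mish()(x), MemoryEfficientMish()(x))

g_plain, = torch.autograd.grad(Mish()(x).sum(), x)
g_fast, = torch.autograd.grad(MemoryEfficientMish()(x).sum(), x)
assert torch.allclose(g_plain, g_fast)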