MemoryEfficientMish()

This commit is contained in:
Glenn Jocher 2020-04-27 13:51:21 -07:00
parent 3aa347a321
commit 2518868508
1 changed files with 27 additions and 7 deletions

View File

@@ -98,14 +98,29 @@ class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels ht
# Activation functions below -------------------------------------------------------------------------------------------
class SwishImplementation(torch.autograd.Function):
    """Memory-efficient Swish: y = x * sigmoid(x).

    Saves only the raw input in forward and recomputes sigmoid(x) in
    backward, instead of letting autograd keep the intermediate
    activations alive — trades a little compute for lower memory use
    during training.
    """

    @staticmethod
    def forward(ctx, x):
        # Save the input; sigmoid is cheap to recompute in backward.
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sx = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        return grad_output * (sx * (1 + x * (1 - sx)))
class MishImplementation(torch.autograd.Function):
    """Memory-efficient Mish: y = x * tanh(softplus(x)).

    Saves only the raw input in forward and recomputes the activation
    terms in backward, reducing training memory at the cost of a small
    amount of extra compute.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sx = torch.sigmoid(x)  # sigmoid(x) = d/dx softplus(x)
        fx = F.softplus(x).tanh()
        # d/dx [x * tanh(softplus(x))] = fx + x * sigmoid(x) * (1 - fx^2)
        return grad_output * (fx + x * sx * (1 - fx * fx))
class MemoryEfficientSwish(nn.Module):
    """Swish activation module backed by the memory-efficient autograd Function."""

    # NOTE(review): the `def forward` line was elided by the diff context;
    # reconstructed by parallel with MemoryEfficientMish — confirm against full file.
    def forward(self, x):
        return SwishImplementation.apply(x)
class MemoryEfficientMish(nn.Module):
    """Mish activation module backed by the memory-efficient autograd Function."""

    def forward(self, x):
        return MishImplementation.apply(x)
class Swish(nn.Module):
    """Plain (non-memory-efficient) Swish activation: y = x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)
class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
@@ -125,4 +145,4 @@ class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
    """Plain (non-memory-efficient) Mish activation: y = x * tanh(softplus(x))."""

    def forward(self, x):
        return x * F.softplus(x).tanh()