MemoryEfficientMish()
This commit is contained in:
parent
3aa347a321
commit
2518868508
|
@ -98,14 +98,29 @@ class MixConv2d(nn.Module): # MixConv: Mixed Depthwise Convolutional Kernels ht
|
|||
# Activation functions below -------------------------------------------------------------------------------------------
|
||||
class SwishImplementation(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, i):
|
||||
ctx.save_for_backward(i)
|
||||
return i * torch.sigmoid(i)
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
|
||||
return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
|
||||
x = ctx.saved_tensors[0]
|
||||
sx = torch.sigmoid(x) # sigmoid(ctx)
|
||||
return grad_output * (sx * (1 + x * (1 - sx)))
|
||||
|
||||
|
||||
class MishImplementation(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
ctx.save_for_backward(x)
|
||||
return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
x = ctx.saved_tensors[0]
|
||||
sx = torch.sigmoid(x)
|
||||
fx = F.softplus(x).tanh()
|
||||
return grad_output * (fx + x * sx * (1 - fx * fx))
|
||||
|
||||
|
||||
class MemoryEfficientSwish(nn.Module):
|
||||
|
@ -113,9 +128,14 @@ class MemoryEfficientSwish(nn.Module):
|
|||
return SwishImplementation.apply(x)
|
||||
|
||||
|
||||
class MemoryEfficientMish(nn.Module):
|
||||
def forward(self, x):
|
||||
return MishImplementation.apply(x)
|
||||
|
||||
|
||||
class Swish(nn.Module):
|
||||
def forward(self, x):
|
||||
return x.mul(torch.sigmoid(x))
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class HardSwish(nn.Module): # https://arxiv.org/pdf/1905.02244.pdf
|
||||
|
@ -125,4 +145,4 @@ class HardSwish(nn.Module): # https://arxiv.org/pdf/1905.02244.pdf
|
|||
|
||||
class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
||||
def forward(self, x):
|
||||
return x.mul(F.softplus(x).tanh())
|
||||
return x * F.softplus(x).tanh()
|
||||
|
|
Loading…
Reference in New Issue