MemoryEfficientMish()

parent 3aa347a321 · commit 2518868508
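In brief: this commit adds a memory-efficient Mish to match the existing memory-efficient Swish. A new MishImplementation autograd Function recomputes intermediates in a hand-written backward instead of letting autograd cache them, and a MemoryEfficientMish module wraps it. Along the way, SwishImplementation's forward argument is renamed from i to x, the deprecated ctx.saved_variables is replaced by ctx.saved_tensors, and the eager Swish/Mish forwards switch from x.mul(...) to the plain * operator.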
@@ -98,14 +98,29 @@ class MixConv2d(nn.Module):  # MixConv: Mixed Depthwise Convolutional Kernels ht
 # Activation functions below -------------------------------------------------------------------------------------------
 class SwishImplementation(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, i):
-        ctx.save_for_backward(i)
-        return i * torch.sigmoid(i)
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return x * torch.sigmoid(x)
 
     @staticmethod
     def backward(ctx, grad_output):
-        sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
-        return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
+        x = ctx.saved_tensors[0]
+        sx = torch.sigmoid(x)  # sigmoid(x)
+        return grad_output * (sx * (1 + x * (1 - sx)))
+
+
+class MishImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        sx = torch.sigmoid(x)
+        fx = F.softplus(x).tanh()
+        return grad_output * (fx + x * sx * (1 - fx * fx))
 
 
 class MemoryEfficientSwish(nn.Module):
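Both hand-written backward passes use closed-form derivatives. For Swish, f(x) = x * sigmoid(x) gives f'(x) = sx * (1 + x * (1 - sx)) with sx = sigmoid(x). For Mish, f(x) = x * tanh(softplus(x)); using d softplus/dx = sigmoid(x) and tanh' = 1 - tanh^2, f'(x) = fx + x * sx * (1 - fx * fx) with fx = tanh(softplus(x)), which are exactly the expressions returned above. A minimal sanity check against autograd's numeric gradients (the utils.layers import path is an assumption; adjust to wherever these classes live):

import torch
from torch.autograd import gradcheck

from utils.layers import SwishImplementation, MishImplementation  # assumed path

x = torch.randn(16, dtype=torch.double, requires_grad=True)  # gradcheck wants float64
print(gradcheck(SwishImplementation.apply, (x,)))  # True if analytic and numeric grads agree
print(gradcheck(MishImplementation.apply, (x,)))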
@@ -113,9 +128,14 @@ class MemoryEfficientSwish(nn.Module):
         return SwishImplementation.apply(x)
 
 
+class MemoryEfficientMish(nn.Module):
+    def forward(self, x):
+        return MishImplementation.apply(x)
+
+
 class Swish(nn.Module):
     def forward(self, x):
-        return x.mul(torch.sigmoid(x))
+        return x * torch.sigmoid(x)
 
 
 class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
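MemoryEfficientMish is a drop-in nn.Module: its forward routes through MishImplementation.apply, so only the input x is saved for backward and the softplus/tanh intermediates are recomputed rather than cached. A usage sketch (import path assumed as above):

import torch.nn as nn
from utils.layers import MemoryEfficientMish  # assumed path

block = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    MemoryEfficientMish(),  # swap in for nn.ReLU(), Mish(), etc.
)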
@@ -125,4 +145,4 @@ class HardSwish(nn.Module):  # https://arxiv.org/pdf/1905.02244.pdf
 
 class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
     def forward(self, x):
-        return x.mul(F.softplus(x).tanh())
+        return x * F.softplus(x).tanh()
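The trade is memory for a little recompute: the plain Mish above lets autograd keep softplus(x) and tanh(...) alive until backward, while MishImplementation stores only x. A rough way to see the difference in peak GPU memory (CUDA and the import path above are assumptions; exact numbers depend on shapes and PyTorch version):

import torch
from utils.layers import Mish, MemoryEfficientMish  # assumed path

def peak_gb(act):
    torch.cuda.reset_peak_memory_stats()
    x = torch.randn(64, 256, 64, 64, device='cuda', requires_grad=True)
    act(x).sum().backward()
    return torch.cuda.max_memory_allocated() / 1e9

print('Mish                %.2f GB' % peak_gb(Mish()))
print('MemoryEfficientMish %.2f GB' % peak_gb(MemoryEfficientMish()))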