add Mish() support
This commit is contained in:
parent
18d4ebfd12
commit
a0a3bab9e6
|
@ -115,9 +115,9 @@ class MemoryEfficientSwish(nn.Module):
|
||||||
|
|
||||||
class Swish(nn.Module):
|
class Swish(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x.mul_(torch.sigmoid(x))
|
return x.mul(torch.sigmoid(x))
|
||||||
|
|
||||||
|
|
||||||
class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x.mul_(F.softplus(x).tanh())
|
return x.mul(F.softplus(x).tanh())
|
||||||
|
|
Loading…
Reference in New Issue