add HardSwish()

This commit is contained in:
Glenn Jocher 2020-04-27 13:08:24 -07:00
parent 692f945819
commit 3aa347a321
1 changed files with 5 additions and 0 deletions

View File

@ -118,6 +118,11 @@ class Swish(nn.Module):
return x.mul(torch.sigmoid(x)) return x.mul(torch.sigmoid(x))
class HardSwish(nn.Module): # https://arxiv.org/pdf/1905.02244.pdf
def forward(self, x):
return x * F.hardtanh(x + 3, 0., 6., True) / 6.
class Mish(nn.Module): # https://github.com/digantamisra98/Mish class Mish(nn.Module): # https://github.com/digantamisra98/Mish
def forward(self, x): def forward(self, x):
return x.mul(F.softplus(x).tanh()) return x.mul(F.softplus(x).tanh())