diff --git a/utils/layers.py b/utils/layers.py index a3630020..81d3408c 100644 --- a/utils/layers.py +++ b/utils/layers.py @@ -118,6 +118,11 @@ class Swish(nn.Module): return x.mul(torch.sigmoid(x)) +class HardSwish(nn.Module): # https://arxiv.org/pdf/1905.02244.pdf + def forward(self, x): + return x * F.hardtanh(x + 3, 0., 6., True) / 6. + + class Mish(nn.Module): # https://github.com/digantamisra98/Mish def forward(self, x): return x.mul(F.softplus(x).tanh())