diff --git a/models.py b/models.py index d352e873..ff3138af 100755 --- a/models.py +++ b/models.py @@ -118,16 +118,15 @@ class Swish(nn.Module): super(Swish, self).__init__() def forward(self, x): - return x * torch.sigmoid(x) + return x.mul_(torch.sigmoid(x)) class Mish(nn.Module): # https://github.com/digantamisra98/Mish - # Applies the mish function element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) def __init__(self): super().__init__() def forward(self, x): - return x * torch.tanh(F.softplus(x)) + return x.mul_(F.softplus(x).tanh()) class YOLOLayer(nn.Module):