diff --git a/models.py b/models.py index 1ca41cf6..2291b65f 100755 --- a/models.py +++ b/models.py @@ -120,6 +120,15 @@ class Swish(nn.Module): return x * torch.sigmoid(x) +class Mish(nn.Module): # https://github.com/digantamisra98/Mish + # Applies the mish function element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + class YOLOLayer(nn.Module): def __init__(self, anchors, nc, img_size, yolo_index, arc): super(YOLOLayer, self).__init__()