diff --git a/models.py b/models.py index 9bfbb114..cc3093d7 100755 --- a/models.py +++ b/models.py @@ -114,18 +114,31 @@ def create_modules(module_defs, img_size, arc): return module_list, routs -class Swish(nn.Module): - def __init__(self): - super().__init__() +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + +class Swish(nn.Module): def forward(self, x): return x.mul_(torch.sigmoid(x)) class Mish(nn.Module): # https://github.com/digantamisra98/Mish - def __init__(self): - super().__init__() - def forward(self, x): return x.mul_(F.softplus(x).tanh())