This commit is contained in:
Glenn Jocher 2019-11-25 17:13:10 -10:00
parent 75e8ec323f
commit 90cfb91858
1 changed file with 19 additions and 6 deletions

View File

@@ -114,18 +114,31 @@ def create_modules(module_defs, img_size, arc):
return module_list, routs return module_list, routs
class SwishImplementation(torch.autograd.Function):
    """Custom autograd function computing ``i * sigmoid(i)``.

    Only the input tensor is stored for the backward pass; the sigmoid is
    recomputed in ``backward`` rather than being kept from ``forward``.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and save ``i`` for the backward pass."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the output with respect to the input.

        Uses d/di [i * s(i)] = s(i) * (1 + i * (1 - s(i))), where s is the
        sigmoid, scaled by the incoming ``grad_output``.
        """
        # `saved_tensors` is the supported accessor; `saved_variables` is a
        # deprecated alias for the same data.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
    """Module wrapper that runs its input through ``SwishImplementation``."""

    def forward(self, x):
        """Apply the custom Swish autograd function to ``x`` and return the result."""
        swish = SwishImplementation.apply
        return swish(x)
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` as a new tensor, leaving ``x`` unchanged."""
        # Out-of-place multiply: an in-place `mul_` would overwrite the
        # caller's tensor, which fails for leaf tensors that require grad and
        # invalidates values autograd needs for the backward pass.
        return x * torch.sigmoid(x)
class Mish(nn.Module):  # https://github.com/digantamisra98/Mish
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        """Return ``x * tanh(softplus(x))`` as a new tensor, leaving ``x`` unchanged."""
        # Out-of-place multiply: an in-place `mul_` would overwrite the
        # caller's tensor, which fails for leaf tensors that require grad and
        # invalidates values autograd needs for the backward pass.
        return x * F.softplus(x).tanh()