This commit is contained in:
Glenn Jocher 2019-11-25 18:42:48 -10:00
parent 3c57ff7b1b
commit b269ed7b29
1 changed files with 3 additions and 5 deletions

View File

@ -117,15 +117,13 @@ def create_modules(module_defs, img_size, arc):
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result
return i * torch.sigmoid(i)
@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_variables[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
class MemoryEfficientSwish(nn.Module):