From b269ed7b2975ee4f646819a78bf8c20771089e29 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 25 Nov 2019 18:42:48 -1000 Subject: [PATCH] updates --- models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/models.py b/models.py index cc3093d7..c75b1ec1 100755 --- a/models.py +++ b/models.py @@ -117,15 +117,13 @@ def create_modules(module_defs, img_size, arc): class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): - result = i * torch.sigmoid(i) ctx.save_for_backward(i) - return result + return i * torch.sigmoid(i) @staticmethod def backward(ctx, grad_output): - i = ctx.saved_variables[0] - sigmoid_i = torch.sigmoid(i) - return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + sigmoid_i = torch.sigmoid(ctx.saved_variables[0]) + return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i))) class MemoryEfficientSwish(nn.Module):