updates
This commit is contained in:
parent
3c57ff7b1b
commit
b269ed7b29
|
@ -117,15 +117,13 @@ def create_modules(module_defs, img_size, arc):
|
||||||
class SwishImplementation(torch.autograd.Function):
|
class SwishImplementation(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, i):
|
def forward(ctx, i):
|
||||||
result = i * torch.sigmoid(i)
|
|
||||||
ctx.save_for_backward(i)
|
ctx.save_for_backward(i)
|
||||||
return result
|
return i * torch.sigmoid(i)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
i = ctx.saved_variables[0]
|
sigmoid_i = torch.sigmoid(ctx.saved_variables[0])
|
||||||
sigmoid_i = torch.sigmoid(i)
|
return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))
|
||||||
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryEfficientSwish(nn.Module):
|
class MemoryEfficientSwish(nn.Module):
|
||||||
|
|
Loading…
Reference in New Issue