updates
This commit is contained in:
parent
75e8ec323f
commit
90cfb91858
25
models.py
25
models.py
|
@ -114,18 +114,31 @@ def create_modules(module_defs, img_size, arc):
|
||||||
return module_list, routs
|
return module_list, routs
|
||||||
|
|
||||||
|
|
||||||
class Swish(nn.Module):
|
class SwishImplementation(torch.autograd.Function):
|
||||||
def __init__(self):
|
@staticmethod
|
||||||
super().__init__()
|
def forward(ctx, i):
|
||||||
|
result = i * torch.sigmoid(i)
|
||||||
|
ctx.save_for_backward(i)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
i = ctx.saved_variables[0]
|
||||||
|
sigmoid_i = torch.sigmoid(i)
|
||||||
|
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEfficientSwish(nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return SwishImplementation.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Swish(nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x.mul_(torch.sigmoid(x))
|
return x.mul_(torch.sigmoid(x))
|
||||||
|
|
||||||
|
|
||||||
class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
class Mish(nn.Module): # https://github.com/digantamisra98/Mish
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x.mul_(F.softplus(x).tanh())
|
return x.mul_(F.softplus(x).tanh())
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue