new modules and init weights
This commit is contained in:
parent
76fb8d48d4
commit
835b0da68a
|
@ -3,6 +3,28 @@ import torch.nn.functional as F
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(v, divisor):
|
||||||
|
# Function ensures all layers have a channel number that is divisible by 8
|
||||||
|
# https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||||
|
return math.ceil(v / divisor) * divisor
|
||||||
|
|
||||||
|
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
|
||||||
|
def forward(self, x):
|
||||||
|
return x.view(x.size(0), -1)
|
||||||
|
|
||||||
|
|
||||||
|
class Concat(nn.Module):
|
||||||
|
# Concatenate a list of tensors along dimension
|
||||||
|
def __init__(self, dimension=1):
|
||||||
|
super(Concat, self).__init__()
|
||||||
|
self.d = dimension
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.cat(x, self.d)
|
||||||
|
|
||||||
|
|
||||||
class FeatureConcat(nn.Module):
|
class FeatureConcat(nn.Module):
|
||||||
def __init__(self, layers):
|
def __init__(self, layers):
|
||||||
super(FeatureConcat, self).__init__()
|
super(FeatureConcat, self).__init__()
|
||||||
|
|
|
@ -52,11 +52,14 @@ def time_synchronized():
|
||||||
|
|
||||||
def initialize_weights(model):
|
def initialize_weights(model):
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, nn.Conv2d):
|
t = type(m)
|
||||||
|
if t is nn.Conv2d:
|
||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif t is nn.BatchNorm2d:
|
||||||
m.eps = 1e-4
|
m.eps = 1e-4
|
||||||
m.momentum = 0.03
|
m.momentum = 0.03
|
||||||
|
elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
||||||
|
m.inplace = True
|
||||||
|
|
||||||
|
|
||||||
def find_modules(model, mclass=nn.Conv2d):
|
def find_modules(model, mclass=nn.Conv2d):
|
||||||
|
|
Loading…
Reference in New Issue