new modules and init weights
This commit is contained in:
parent
76fb8d48d4
commit
835b0da68a
|
@ -3,6 +3,28 @@ import torch.nn.functional as F
|
|||
from utils.utils import *
|
||||
|
||||
|
||||
def make_divisible(v, divisor):
|
||||
# Function ensures all layers have a channel number that is divisible by 8
|
||||
# https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
||||
return math.ceil(v / divisor) * divisor
|
||||
|
||||
|
||||
class Flatten(nn.Module):
|
||||
# Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
|
||||
def forward(self, x):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class Concat(nn.Module):
|
||||
# Concatenate a list of tensors along dimension
|
||||
def __init__(self, dimension=1):
|
||||
super(Concat, self).__init__()
|
||||
self.d = dimension
|
||||
|
||||
def forward(self, x):
|
||||
return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class FeatureConcat(nn.Module):
|
||||
def __init__(self, layers):
|
||||
super(FeatureConcat, self).__init__()
|
||||
|
|
|
@ -52,11 +52,14 @@ def time_synchronized():
|
|||
|
||||
def initialize_weights(model):
|
||||
for m in model.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
t = type(m)
|
||||
if t is nn.Conv2d:
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
elif t is nn.BatchNorm2d:
|
||||
m.eps = 1e-4
|
||||
m.momentum = 0.03
|
||||
elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
|
||||
m.inplace = True
|
||||
|
||||
|
||||
def find_modules(model, mclass=nn.Conv2d):
|
||||
|
|
Loading…
Reference in New Issue