new modules and init weights

Glenn Jocher 2020-04-14 01:20:57 -07:00
parent 76fb8d48d4
commit 835b0da68a
2 changed files with 27 additions and 2 deletions

models.py

@@ -3,6 +3,28 @@ import torch.nn.functional as F
from utils.utils import *
def make_divisible(v, divisor):
    # Ensures all layers have a channel number that is divisible by divisor (8 in the MobileNet reference)
    # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    return math.ceil(v / divisor) * divisor
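A quick sanity check of what make_divisible returns (a standalone sketch; in models.py the function relies on math being imported at the top of the file):

import math

def make_divisible(v, divisor):
    return math.ceil(v / divisor) * divisor

print(make_divisible(30, 8))  # 32: channel count rounded up to the nearest multiple of 8
print(make_divisible(32, 8))  # 32: already divisible, returned unchanged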
class Flatten(nn.Module):
    # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
    def forward(self, x):
        return x.view(x.size(0), -1)
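Flatten is meant to follow global average pooling, as its comment says; a minimal usage sketch of the class above, with made-up shapes:

import torch
import torch.nn as nn

head = nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten())
x = torch.randn(16, 512, 7, 7)  # NCHW feature map
print(head(x).shape)            # torch.Size([16, 512]): trailing 1x1 dims removed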
class Concat(nn.Module):
    # Concatenate a list of tensors along a given dimension
    def __init__(self, dimension=1):
        super(Concat, self).__init__()
        self.d = dimension

    def forward(self, x):
        return torch.cat(x, self.d)
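Concat just wraps torch.cat in a module so the concatenation point can live inside the model graph; a minimal sketch using the class above, with made-up shapes:

import torch

cat = Concat(dimension=1)
a = torch.randn(2, 64, 13, 13)
b = torch.randn(2, 128, 13, 13)
print(cat([a, b]).shape)  # torch.Size([2, 192, 13, 13]): channels joined along dim 1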
class FeatureConcat(nn.Module):
    # Concatenate the outputs of the specified layer indices
    def __init__(self, layers):
        super(FeatureConcat, self).__init__()
        self.layers = layers  # layer indices to concatenate
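The hunk is truncated before FeatureConcat's forward appears, so the following is only a sketch of the usual pattern (an assumption, not the commit's code): gather the cached outputs of the named layers and concatenate them along the channel dimension.

import torch
import torch.nn as nn

class FeatureConcatSketch(nn.Module):
    # hypothetical stand-in for the truncated class above
    def __init__(self, layers):
        super(FeatureConcatSketch, self).__init__()
        self.layers = layers  # indices into the list of cached layer outputs

    def forward(self, x, outputs):
        # 'outputs' is assumed to hold each earlier layer's activation
        return torch.cat([outputs[i] for i in self.layers], 1)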

utils/torch_utils.py

@@ -52,11 +52,14 @@ def time_synchronized():
def initialize_weights(model):
    for m in model.modules():
-       if isinstance(m, nn.Conv2d):
+       t = type(m)
+       if t is nn.Conv2d:
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-       elif isinstance(m, nn.BatchNorm2d):
+       elif t is nn.BatchNorm2d:
            m.eps = 1e-4
            m.momentum = 0.03
+       elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+           m.inplace = True


def find_modules(model, mclass=nn.Conv2d):
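Note the switch from isinstance checks to exact type(m) comparisons: subclasses of nn.Conv2d or nn.BatchNorm2d will no longer match. Usage is one call after model construction; a minimal sketch (the Sequential model is made up for illustration and assumes initialize_weights as defined above):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.LeakyReLU())
initialize_weights(model)
print(model[1].eps, model[1].momentum)  # 0.0001 0.03
print(model[2].inplace)                 # True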