diff --git a/utils/layers.py b/utils/layers.py
index dce3fed2..7662c6bd 100644
--- a/utils/layers.py
+++ b/utils/layers.py
@@ -3,6 +3,28 @@ import torch.nn.functional as F
 from utils.utils import *
 
 
+def make_divisible(v, divisor):
+    # Function ensures all layers have a channel number that is divisible by 8
+    # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    return math.ceil(v / divisor) * divisor
+
+
+class Flatten(nn.Module):
+    # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+class Concat(nn.Module):
+    # Concatenate a list of tensors along dimension
+    def __init__(self, dimension=1):
+        super(Concat, self).__init__()
+        self.d = dimension
+
+    def forward(self, x):
+        return torch.cat(x, self.d)
+
+
 class FeatureConcat(nn.Module):
     def __init__(self, layers):
         super(FeatureConcat, self).__init__()
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 0e1ade3a..e5286cd0 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -52,11 +52,14 @@ def time_synchronized():
 
 def initialize_weights(model):
     for m in model.modules():
-        if isinstance(m, nn.Conv2d):
+        t = type(m)
+        if t is nn.Conv2d:
             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
-        elif isinstance(m, nn.BatchNorm2d):
+        elif t is nn.BatchNorm2d:
             m.eps = 1e-4
             m.momentum = 0.03
+        elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+            m.inplace = True
 
 
 def find_modules(model, mclass=nn.Conv2d):
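
For reference (not part of the commit), a minimal usage sketch of the three helpers added to utils/layers.py; the import path follows the diff above, while the tensor shapes and example values are illustrative assumptions:

import torch
import torch.nn as nn
from utils.layers import Concat, Flatten, make_divisible  # module path assumed from the diff

x = torch.randn(2, 16, 8, 8)   # hypothetical feature map (N, C, H, W)
y = torch.randn(2, 32, 8, 8)   # second feature map with more channels

print(make_divisible(37, 8))   # 40: channel count rounded up to a multiple of 8

head = nn.Sequential(nn.AdaptiveAvgPool2d(1), Flatten())
print(head(x).shape)           # torch.Size([2, 16]): pooled map flattened to (N, C)

cat = Concat(dimension=1)
print(cat([x, y]).shape)       # torch.Size([2, 48, 8, 8]): channel-wise concatenation

On the torch_utils.py side, replacing isinstance() with type() makes the checks exact (subclasses of nn.Conv2d or nn.BatchNorm2d no longer match), and the new branch switches LeakyReLU/ReLU/ReLU6 modules to in-place mode, which trades a bit of autograd flexibility for lower activation memory.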