new modules and init weights
This commit is contained in:
		
							parent
							
								
									76fb8d48d4
								
							
						
					
					
						commit
						835b0da68a
					
				|  | @ -3,6 +3,28 @@ import torch.nn.functional as F | ||||||
| from utils.utils import * | from utils.utils import * | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def make_divisible(v, divisor): | ||||||
|  |     # Function ensures all layers have a channel number that is divisible by 8 | ||||||
|  |     # https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py | ||||||
|  |     return math.ceil(v / divisor) * divisor | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Flatten(nn.Module): | ||||||
|  |     # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions | ||||||
|  |     def forward(self, x): | ||||||
|  |         return x.view(x.size(0), -1) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class Concat(nn.Module): | ||||||
|  |     # Concatenate a list of tensors along dimension | ||||||
|  |     def __init__(self, dimension=1): | ||||||
|  |         super(Concat, self).__init__() | ||||||
|  |         self.d = dimension | ||||||
|  | 
 | ||||||
|  |     def forward(self, x): | ||||||
|  |         return torch.cat(x, self.d) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class FeatureConcat(nn.Module): | class FeatureConcat(nn.Module): | ||||||
|     def __init__(self, layers): |     def __init__(self, layers): | ||||||
|         super(FeatureConcat, self).__init__() |         super(FeatureConcat, self).__init__() | ||||||
|  |  | ||||||
|  | @ -52,11 +52,14 @@ def time_synchronized(): | ||||||
| 
 | 
 | ||||||
| def initialize_weights(model): | def initialize_weights(model): | ||||||
|     for m in model.modules(): |     for m in model.modules(): | ||||||
|         if isinstance(m, nn.Conv2d): |         t = type(m) | ||||||
|  |         if t is nn.Conv2d: | ||||||
|             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') |             nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||||||
|         elif isinstance(m, nn.BatchNorm2d): |         elif t is nn.BatchNorm2d: | ||||||
|             m.eps = 1e-4 |             m.eps = 1e-4 | ||||||
|             m.momentum = 0.03 |             m.momentum = 0.03 | ||||||
|  |         elif t in [nn.LeakyReLU, nn.ReLU, nn.ReLU6]: | ||||||
|  |             m.inplace = True | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def find_modules(model, mclass=nn.Conv2d): | def find_modules(model, mclass=nn.Conv2d): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue