imagenet normalization on layer 0 batchnorm2d()
This commit is contained in:
parent
b70cfa9a29
commit
2baf4e3f93
|
@ -52,6 +52,11 @@ def create_modules(module_defs, img_size):
|
|||
elif mdef['type'] == 'BatchNorm2d':
|
||||
filters = output_filters[-1]
|
||||
modules = nn.BatchNorm2d(filters, momentum=0.03, eps=1E-4)
|
||||
if i == 0 and filters == 3: # normalize RGB image
|
||||
# imagenet mean and var https://pytorch.org/docs/stable/torchvision/models.html#classification
|
||||
modules.running_mean = torch.tensor([0.485, 0.456, 0.406])
|
||||
modules.running_var = torch.tensor([0.0524, 0.0502, 0.0506])
|
||||
modules.momentum = 0.003
|
||||
|
||||
elif mdef['type'] == 'maxpool':
|
||||
size = mdef['size']
|
||||
|
|
Loading…
Reference in New Issue