imagenet normalization on layer 0 batchnorm2d()

This commit is contained in:
Glenn Jocher 2020-04-05 17:33:06 -07:00
parent b70cfa9a29
commit 2baf4e3f93
1 changed files with 5 additions and 0 deletions

View File

@ -52,6 +52,11 @@ def create_modules(module_defs, img_size):
elif mdef['type'] == 'BatchNorm2d':
filters = output_filters[-1]
modules = nn.BatchNorm2d(filters, momentum=0.03, eps=1E-4)
if i == 0 and filters == 3: # normalize RGB image
# imagenet mean and var https://pytorch.org/docs/stable/torchvision/models.html#classification
modules.running_mean = torch.tensor([0.485, 0.456, 0.406])
modules.running_var = torch.tensor([0.0524, 0.0502, 0.0506])
modules.momentum = 0.003
elif mdef['type'] == 'maxpool':
size = mdef['size']