diff --git a/models.py b/models.py index c4a29aee..3d6b3ad6 100755 --- a/models.py +++ b/models.py @@ -52,6 +52,11 @@ def create_modules(module_defs, img_size): elif mdef['type'] == 'BatchNorm2d': filters = output_filters[-1] modules = nn.BatchNorm2d(filters, momentum=0.03, eps=1E-4) + if i == 0 and filters == 3: # normalize RGB image + # imagenet mean and var https://pytorch.org/docs/stable/torchvision/models.html#classification + modules.running_mean = torch.tensor([0.485, 0.456, 0.406]) + modules.running_var = torch.tensor([0.0524, 0.0502, 0.0506]) + modules.momentum = 0.003 elif mdef['type'] == 'maxpool': size = mdef['size']