Replace utils.weights_init_normal with torch_utils.initialize_weights.

* models.py: add a (still commented-out) call to
  torch_utils.initialize_weights(self) in Darknet, right after the YOLO
  layers are collected.
* utils/torch_utils.py: add initialize_weights(model). It walks
  model.modules(), applies Kaiming-normal init (mode='fan_out',
  nonlinearity='relu') to nn.Conv2d weights, and resets nn.BatchNorm2d
  to weight=1, bias=0.
* utils/utils.py: remove weights_init_normal(m), which matched modules by
  class-name substring and drew Conv weights from normal(0.0, 0.03) and
  BatchNorm2d weights from normal(1.0, 0.03) with bias set to 0.

diff --git a/models.py b/models.py
index 683c3a7f..4ecf99e6 100755
--- a/models.py
+++ b/models.py
@@ -225,6 +225,7 @@ class Darknet(nn.Module):
         self.module_defs = parse_model_cfg(cfg)
         self.module_list, self.routs = create_modules(self.module_defs, img_size)
         self.yolo_layers = get_yolo_layers(self)
+        # torch_utils.initialize_weights(self)
 
         # Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
         self.version = np.array([0, 2, 5], dtype=np.int32)  # (int32) version info: major, minor, revision
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 5819e68e..f63bc110 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -50,6 +50,15 @@ def time_synchronized():
     return time.time()
 
 
+def initialize_weights(model):
+    for m in model.modules():
+        if isinstance(m, nn.Conv2d):
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+        elif isinstance(m, nn.BatchNorm2d):
+            m.weight.data.fill_(1)
+            m.bias.data.zero_()
+
+
 def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
diff --git a/utils/utils.py b/utils/utils.py
index ab121c64..60443b0d 100755
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -93,15 +93,6 @@ def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
     return x
 
 
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.03)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.03)
-        torch.nn.init.constant_(m.bias.data, 0.0)
-
-
 def xyxy2xywh(x):
     # Transform box coordinates from [x1, y1, x2, y2] (where xy1=top-left, xy2=bottom-right) to [x, y, w, h]
     y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)