kaiming weight init

This commit is contained in:
Glenn Jocher 2020-04-11 10:45:33 -07:00
parent 2cf23c4aee
commit 58edfc4a84
3 changed files with 10 additions and 9 deletions

View File

@ -225,6 +225,7 @@ class Darknet(nn.Module):
self.module_defs = parse_model_cfg(cfg) self.module_defs = parse_model_cfg(cfg)
self.module_list, self.routs = create_modules(self.module_defs, img_size) self.module_list, self.routs = create_modules(self.module_defs, img_size)
self.yolo_layers = get_yolo_layers(self) self.yolo_layers = get_yolo_layers(self)
# torch_utils.initialize_weights(self)
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346 # Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision

View File

@ -50,6 +50,15 @@ def time_synchronized():
return time.time() return time.time()
def initialize_weights(model):
for m in model.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def fuse_conv_and_bn(conv, bn): def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/ # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad(): with torch.no_grad():

View File

@ -93,15 +93,6 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
return x return x
def weights_init_normal(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.03)
elif classname.find('BatchNorm2d') != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.03)
torch.nn.init.constant_(m.bias.data, 0.0)
def xyxy2xywh(x): def xyxy2xywh(x):
# Transform box coordinates from [x1, y1, x2, y2] (where xy1=top-left, xy2=bottom-right) to [x, y, w, h] # Transform box coordinates from [x1, y1, x2, y2] (where xy1=top-left, xy2=bottom-right) to [x, y, w, h]
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x) y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)