kaiming weight init
This commit is contained in:
parent
2cf23c4aee
commit
58edfc4a84
|
@ -225,6 +225,7 @@ class Darknet(nn.Module):
|
||||||
self.module_defs = parse_model_cfg(cfg)
|
self.module_defs = parse_model_cfg(cfg)
|
||||||
self.module_list, self.routs = create_modules(self.module_defs, img_size)
|
self.module_list, self.routs = create_modules(self.module_defs, img_size)
|
||||||
self.yolo_layers = get_yolo_layers(self)
|
self.yolo_layers = get_yolo_layers(self)
|
||||||
|
# torch_utils.initialize_weights(self)
|
||||||
|
|
||||||
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||||
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision
|
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision
|
||||||
|
|
|
@ -50,6 +50,15 @@ def time_synchronized():
|
||||||
return time.time()
|
return time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_weights(model):
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, nn.Conv2d):
|
||||||
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
m.weight.data.fill_(1)
|
||||||
|
m.bias.data.zero_()
|
||||||
|
|
||||||
|
|
||||||
def fuse_conv_and_bn(conv, bn):
|
def fuse_conv_and_bn(conv, bn):
|
||||||
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
@ -93,15 +93,6 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def weights_init_normal(m):
|
|
||||||
classname = m.__class__.__name__
|
|
||||||
if classname.find('Conv') != -1:
|
|
||||||
torch.nn.init.normal_(m.weight.data, 0.0, 0.03)
|
|
||||||
elif classname.find('BatchNorm2d') != -1:
|
|
||||||
torch.nn.init.normal_(m.weight.data, 1.0, 0.03)
|
|
||||||
torch.nn.init.constant_(m.bias.data, 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
def xyxy2xywh(x):
|
def xyxy2xywh(x):
|
||||||
# Transform box coordinates from [x1, y1, x2, y2] (where xy1=top-left, xy2=bottom-right) to [x, y, w, h]
|
# Transform box coordinates from [x1, y1, x2, y2] (where xy1=top-left, xy2=bottom-right) to [x, y, w, h]
|
||||||
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
|
y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
|
||||||
|
|
Loading…
Reference in New Issue