updates
This commit is contained in:
parent
c46e156ff8
commit
d7a28bd9f7
21
models.py
21
models.py
|
@ -181,9 +181,9 @@ class Darknet(nn.Module):
|
|||
self.hyperparams, self.module_list = create_modules(self.module_defs)
|
||||
self.yolo_layers = get_yolo_layers(self)
|
||||
|
||||
# Needed to write header when saving weights
|
||||
self.header_info = np.zeros(5, dtype=np.int32) # First five are header values
|
||||
self.seen = self.header_info[3] # number of images seen during training
|
||||
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision
|
||||
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
|
||||
|
||||
def forward(self, x, var=None):
|
||||
img_size = max(x.shape[-2:])
|
||||
|
@ -274,14 +274,12 @@ def load_darknet_weights(self, weights, cutoff=-1):
|
|||
elif weights_file == 'yolov3-tiny.conv.15':
|
||||
cutoff = 15
|
||||
|
||||
# Open the weights file
|
||||
# Read weights file
|
||||
with open(weights, 'rb') as f:
|
||||
header = np.fromfile(f, dtype=np.int32, count=5) # First five are header values
|
||||
# Read Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||
self.version = np.fromfile(f, dtype=np.int32, count=3) # (int32) version info: major, minor, revision
|
||||
self.seen = np.fromfile(f, dtype=np.int64, count=1) # (int64) number of images seen during training
|
||||
|
||||
# Needed to write header when saving weights
|
||||
self.header_info = header
|
||||
|
||||
self.seen = header[3] # number of images seen during training
|
||||
weights = np.fromfile(f, dtype=np.float32) # The rest are weights
|
||||
|
||||
ptr = 0
|
||||
|
@ -327,8 +325,9 @@ def save_weights(self, path='model.weights', cutoff=-1):
|
|||
# Converts a PyTorch model to Darket format (*.pt to *.weights)
|
||||
# Note: Does not work if model.fuse() is applied
|
||||
with open(path, 'wb') as f:
|
||||
self.header_info[3] = self.seen # number of images seen during training
|
||||
self.header_info.tofile(f)
|
||||
# Write Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
|
||||
self.version.tofile(f) # (int32) version info: major, minor, revision
|
||||
self.seen.tofile(f) # (int64) number of images seen during training
|
||||
|
||||
# Iterate through layers
|
||||
for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):
|
||||
|
|
Loading…
Reference in New Issue