This commit is contained in:
Glenn Jocher 2019-06-05 13:49:56 +02:00
parent c46e156ff8
commit d7a28bd9f7
1 changed files with 10 additions and 11 deletions

View File

@ -181,9 +181,9 @@ class Darknet(nn.Module):
self.hyperparams, self.module_list = create_modules(self.module_defs)
self.yolo_layers = get_yolo_layers(self)
# Needed to write header when saving weights
self.header_info = np.zeros(5, dtype=np.int32) # First five are header values
self.seen = self.header_info[3] # number of images seen during training
# Darknet Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
self.version = np.array([0, 2, 5], dtype=np.int32) # (int32) version info: major, minor, revision
self.seen = np.array([0], dtype=np.int64) # (int64) number of images seen during training
def forward(self, x, var=None):
img_size = max(x.shape[-2:])
@ -274,14 +274,12 @@ def load_darknet_weights(self, weights, cutoff=-1):
elif weights_file == 'yolov3-tiny.conv.15':
cutoff = 15
# Open the weights file
# Read weights file
with open(weights, 'rb') as f:
header = np.fromfile(f, dtype=np.int32, count=5) # First five are header values
# Read Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
self.version = np.fromfile(f, dtype=np.int32, count=3) # (int32) version info: major, minor, revision
self.seen = np.fromfile(f, dtype=np.int64, count=1) # (int64) number of images seen during training
# Needed to write header when saving weights
self.header_info = header
self.seen = header[3] # number of images seen during training
weights = np.fromfile(f, dtype=np.float32) # The rest are weights
ptr = 0
@ -327,8 +325,9 @@ def save_weights(self, path='model.weights', cutoff=-1):
# Converts a PyTorch model to Darket format (*.pt to *.weights)
# Note: Does not work if model.fuse() is applied
with open(path, 'wb') as f:
self.header_info[3] = self.seen # number of images seen during training
self.header_info.tofile(f)
# Write Header https://github.com/AlexeyAB/darknet/issues/2914#issuecomment-496675346
self.version.tofile(f) # (int32) version info: major, minor, revision
self.seen.tofile(f) # (int64) number of images seen during training
# Iterate through layers
for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):