This commit is contained in:
Glenn Jocher 2019-04-23 16:48:47 +02:00
parent 334c7c94cf
commit 85a4cf0042
1 changed files with 57 additions and 31 deletions

View File

@ -172,15 +172,19 @@ class YOLOLayer(nn.Module):
class Darknet(nn.Module):
"""YOLOv3 object detection model"""
def __init__(self, cfg_path, img_size=416):
def __init__(self, cfg, img_size=(416, 416)):
super(Darknet, self).__init__()
self.module_defs = parse_model_cfg(cfg_path)
self.module_defs[0]['cfg'] = cfg_path
self.module_defs = parse_model_cfg(cfg)
self.module_defs[0]['cfg'] = cfg
self.module_defs[0]['height'] = img_size
self.hyperparams, self.module_list = create_modules(self.module_defs)
self.yolo_layers = get_yolo_layers(self)
# Needed to write header when saving weights
self.header_info = np.zeros(5, dtype=np.int32) # First five are header values
self.seen = self.header_info[3] # number of images seen during training
def forward(self, x, var=None):
img_size = max(x.shape[-2:])
layer_outputs = []
@ -270,15 +274,14 @@ def load_darknet_weights(self, weights, cutoff=-1):
cutoff = 15
# Open the weights file
fp = open(weights, 'rb')
header = np.fromfile(fp, dtype=np.int32, count=5) # First five are header values
with open(weights, 'rb') as f:
header = np.fromfile(f, dtype=np.int32, count=5) # First five are header values
# Needed to write header when saving weights
self.header_info = header
# Needed to write header when saving weights
self.header_info = header
self.seen = header[3] # number of images seen during training
weights = np.fromfile(fp, dtype=np.float32) # The rest are weights
fp.close()
self.seen = header[3] # number of images seen during training
weights = np.fromfile(f, dtype=np.float32) # The rest are weights
ptr = 0
for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):
@ -319,26 +322,49 @@ def load_darknet_weights(self, weights, cutoff=-1):
return cutoff
def save_weights(self, path, cutoff=-1):
fp = open(path, 'wb')
self.header_info[3] = self.seen # number of images seen during training
self.header_info.tofile(fp)
def save_weights(self, path='model.weights', cutoff=-1):
# Converts a PyTorch model to Darket format (*.pt to *.weights)
# Note: Does not work if model.fuse() is applied
with open(path, 'wb') as f:
self.header_info[3] = self.seen # number of images seen during training
self.header_info.tofile(f)
# Iterate through layers
for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):
if module_def['type'] == 'convolutional':
conv_layer = module[0]
# If batch norm, load bn first
if module_def['batch_normalize']:
bn_layer = module[1]
bn_layer.bias.data.cpu().numpy().tofile(fp)
bn_layer.weight.data.cpu().numpy().tofile(fp)
bn_layer.running_mean.data.cpu().numpy().tofile(fp)
bn_layer.running_var.data.cpu().numpy().tofile(fp)
# Load conv bias
else:
conv_layer.bias.data.cpu().numpy().tofile(fp)
# Load conv weights
conv_layer.weight.data.cpu().numpy().tofile(fp)
# Iterate through layers
for i, (module_def, module) in enumerate(zip(self.module_defs[:cutoff], self.module_list[:cutoff])):
if module_def['type'] == 'convolutional':
conv_layer = module[0]
# If batch norm, load bn first
if module_def['batch_normalize']:
bn_layer = module[1]
bn_layer.bias.data.cpu().numpy().tofile(f)
bn_layer.weight.data.cpu().numpy().tofile(f)
bn_layer.running_mean.data.cpu().numpy().tofile(f)
bn_layer.running_var.data.cpu().numpy().tofile(f)
# Load conv bias
else:
conv_layer.bias.data.cpu().numpy().tofile(f)
# Load conv weights
conv_layer.weight.data.cpu().numpy().tofile(f)
fp.close()
def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
# Converts between PyTorch and Darknet format per extension (i.e. *.weights convert to *.pt and vice versa)
# from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')
# Initialize model
model = Darknet(cfg)
# Load weights and save
if weights.endswith('.pt'): # if PyTorch format
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
save_weights(model, path='converted.weights', cutoff=-1)
print("Success: converted '%s' to 'converted.weights'" % weights)
elif weights.endswith('.weights'): # darknet format
_ = load_darknet_weights(model, weights)
chkpt = {'epoch': -1, 'best_loss': None, 'model': model.state_dict(), 'optimizer': None}
torch.save(chkpt, 'converted.pt')
print("Success: converted '%s' to 'converted.pt'" % weights)
else:
print('Error: extension not supported.')