convert(...) changed to save converted file alongside the original file (#1167)

This commit is contained in:
IlyaOvodov 2020-05-13 19:08:55 +03:00 committed by GitHub
parent c066d7d439
commit b2fcfc573e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 6 deletions

View File

@ -107,11 +107,11 @@ $ git clone https://github.com/ultralytics/yolov3 && cd yolov3
# convert darknet cfg/weights to pytorch model # convert darknet cfg/weights to pytorch model
$ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')" $ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')"
Success: converted 'weights/yolov3-spp.weights' to 'converted.pt' Success: converted 'weights/yolov3-spp.weights' to 'weights/yolov3-spp.pt'
# convert cfg/pytorch model to darknet weights # convert cfg/pytorch model to darknet weights
$ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')" $ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')"
Success: converted 'weights/yolov3-spp.pt' to 'converted.weights' Success: converted 'weights/yolov3-spp.pt' to 'weights/yolov3-spp.weights'
``` ```
# mAP # mAP

View File

@ -423,8 +423,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
# Load weights and save # Load weights and save
if weights.endswith('.pt'): # if PyTorch format if weights.endswith('.pt'): # if PyTorch format
model.load_state_dict(torch.load(weights, map_location='cpu')['model']) model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
save_weights(model, path='converted.weights', cutoff=-1) target = weights.rsplit('.', 1)[0] + '.weights'
print("Success: converted '%s' to 'converted.weights'" % weights) save_weights(model, path=target, cutoff=-1)
print("Success: converted '%s' to '%s'" % (weights, target))
elif weights.endswith('.weights'): # darknet format elif weights.endswith('.weights'): # darknet format
_ = load_darknet_weights(model, weights) _ = load_darknet_weights(model, weights)
@ -435,8 +436,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
'model': model.state_dict(), 'model': model.state_dict(),
'optimizer': None} 'optimizer': None}
torch.save(chkpt, 'converted.pt') target = weights.rsplit('.', 1)[0] + '.pt'
print("Success: converted '%s' to 'converted.pt'" % weights) torch.save(chkpt, target)
print("Success: converted '%s' to '%'" % (weights, target))
else: else:
print('Error: extension not supported.') print('Error: extension not supported.')