From b2fcfc573e5418c0b2ef0c0357bf51bc5cb027b6 Mon Sep 17 00:00:00 2001 From: IlyaOvodov <34230114+IlyaOvodov@users.noreply.github.com> Date: Wed, 13 May 2020 19:08:55 +0300 Subject: [PATCH] convert(...) changed to save converted file alongside the original file (#1167) --- README.md | 4 ++-- models.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index baa2ccbf..c1e62d0f 100755 --- a/README.md +++ b/README.md @@ -107,11 +107,11 @@ $ git clone https://github.com/ultralytics/yolov3 && cd yolov3 # convert darknet cfg/weights to pytorch model $ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')" -Success: converted 'weights/yolov3-spp.weights' to 'converted.pt' +Success: converted 'weights/yolov3-spp.weights' to 'weights/yolov3-spp.pt' # convert cfg/pytorch model to darknet weights $ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')" -Success: converted 'weights/yolov3-spp.pt' to 'converted.weights' +Success: converted 'weights/yolov3-spp.pt' to 'weights/yolov3-spp.weights' ``` # mAP diff --git a/models.py b/models.py index afd5b87b..ebe151b6 100755 --- a/models.py +++ b/models.py @@ -423,8 +423,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'): # Load weights and save if weights.endswith('.pt'): # if PyTorch format model.load_state_dict(torch.load(weights, map_location='cpu')['model']) - save_weights(model, path='converted.weights', cutoff=-1) - print("Success: converted '%s' to 'converted.weights'" % weights) + target = weights.rsplit('.', 1)[0] + '.weights' + save_weights(model, path=target, cutoff=-1) + print("Success: converted '%s' to '%s'" % (weights, target)) elif weights.endswith('.weights'): # darknet format _ = load_darknet_weights(model, weights) @@ -435,8 +436,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'): 'model': model.state_dict(), 'optimizer': None} - torch.save(chkpt, 'converted.pt') - print("Success: converted '%s' to 'converted.pt'" % weights) + target = weights.rsplit('.', 1)[0] + '.pt' + torch.save(chkpt, target) + print("Success: converted '%s' to '%'" % (weights, target)) else: print('Error: extension not supported.')