convert(...) changed to save converted file alongside the original file (#1167)
This commit is contained in:
parent
c066d7d439
commit
b2fcfc573e
|
@ -107,11 +107,11 @@ $ git clone https://github.com/ultralytics/yolov3 && cd yolov3
|
|||
|
||||
# convert darknet cfg/weights to pytorch model
|
||||
$ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.weights')"
|
||||
Success: converted 'weights/yolov3-spp.weights' to 'converted.pt'
|
||||
Success: converted 'weights/yolov3-spp.weights' to 'weights/yolov3-spp.pt'
|
||||
|
||||
# convert cfg/pytorch model to darknet weights
|
||||
$ python3 -c "from models import *; convert('cfg/yolov3-spp.cfg', 'weights/yolov3-spp.pt')"
|
||||
Success: converted 'weights/yolov3-spp.pt' to 'converted.weights'
|
||||
Success: converted 'weights/yolov3-spp.pt' to 'weights/yolov3-spp.weights'
|
||||
```
|
||||
|
||||
# mAP
|
||||
|
|
10
models.py
10
models.py
|
@ -423,8 +423,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
|
|||
# Load weights and save
|
||||
if weights.endswith('.pt'): # if PyTorch format
|
||||
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
|
||||
save_weights(model, path='converted.weights', cutoff=-1)
|
||||
print("Success: converted '%s' to 'converted.weights'" % weights)
|
||||
target = weights.rsplit('.', 1)[0] + '.weights'
|
||||
save_weights(model, path=target, cutoff=-1)
|
||||
print("Success: converted '%s' to '%s'" % (weights, target))
|
||||
|
||||
elif weights.endswith('.weights'): # darknet format
|
||||
_ = load_darknet_weights(model, weights)
|
||||
|
@ -435,8 +436,9 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
|
|||
'model': model.state_dict(),
|
||||
'optimizer': None}
|
||||
|
||||
torch.save(chkpt, 'converted.pt')
|
||||
print("Success: converted '%s' to 'converted.pt'" % weights)
|
||||
target = weights.rsplit('.', 1)[0] + '.pt'
|
||||
torch.save(chkpt, target)
|
||||
print("Success: converted '%s' to '%'" % (weights, target))
|
||||
|
||||
else:
|
||||
print('Error: extension not supported.')
|
||||
|
|
Loading…
Reference in New Issue