updates
This commit is contained in:
parent
8dec060504
commit
08f051c1d4
|
@ -30,9 +30,7 @@ def detect(
|
||||||
if weights.endswith('.pt'): # pytorch format
|
if weights.endswith('.pt'): # pytorch format
|
||||||
if weights.endswith('weights/yolov3.pt') and not os.path.isfile(weights):
|
if weights.endswith('weights/yolov3.pt') and not os.path.isfile(weights):
|
||||||
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights)
|
os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights)
|
||||||
checkpoint = torch.load(weights, map_location='cpu')
|
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
|
||||||
model.load_state_dict(checkpoint['model'])
|
|
||||||
del checkpoint
|
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_darknet_weights(model, weights)
|
load_darknet_weights(model, weights)
|
||||||
|
|
||||||
|
|
4
test.py
4
test.py
|
@ -29,9 +29,7 @@ def test(
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
if weights.endswith('.pt'): # pytorch format
|
if weights.endswith('.pt'): # pytorch format
|
||||||
checkpoint = torch.load(weights, map_location='cpu')
|
model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
|
||||||
model.load_state_dict(checkpoint['model'])
|
|
||||||
del checkpoint
|
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
load_darknet_weights(model, weights)
|
load_darknet_weights(model, weights)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue