diff --git a/detect.py b/detect.py index b1afa028..0bf70621 100755 --- a/detect.py +++ b/detect.py @@ -36,12 +36,11 @@ def detect( if weights_file_path.endswith('.pt'): # pytorch format if weights_file_path.endswith('weights/yolov3.pt') and not os.path.isfile(weights_file_path): os.system('wget https://storage.googleapis.com/ultralytics/yolov3.pt -O ' + weights_file_path) - else: # darknet format - load_weights(model, weights_file_path) - checkpoint = torch.load(weights_file_path, map_location='cpu') model.load_state_dict(checkpoint['model']) del checkpoint + else: # darknet format + load_weights(model, weights_file_path) # current = model.state_dict() # saved = checkpoint['model']