updates
This commit is contained in:
parent
870020ed15
commit
c24702941f
54
models.py
54
models.py
|
@ -291,27 +291,9 @@ def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float
|
|||
|
||||
def load_darknet_weights(self, weights, cutoff=-1):
|
||||
# Parses and loads the weights stored in 'weights'
|
||||
# cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved)
|
||||
|
||||
# Establish cutoffs (load layers between 0 and cutoff. if cutoff = -1 all are loaded)
|
||||
file = Path(weights).name
|
||||
|
||||
# Try to download weights if not available locally
|
||||
msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI'
|
||||
if not os.path.isfile(weights):
|
||||
if file == 'yolov3-spp.weights':
|
||||
gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights)
|
||||
elif file == 'darknet53.conv.74':
|
||||
gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights)
|
||||
else:
|
||||
try: # download from pjreddie.com
|
||||
url = 'https://pjreddie.com/media/files/' + file
|
||||
print('Downloading ' + url)
|
||||
os.system('curl -f ' + url + ' -o ' + weights)
|
||||
except IOError:
|
||||
print(msg)
|
||||
os.system('rm ' + weights) # remove partial downloads
|
||||
assert os.path.exists(weights), msg # download missing weights from Google Drive
|
||||
|
||||
# Establish cutoffs
|
||||
if file == 'darknet53.conv.74':
|
||||
cutoff = 75
|
||||
elif file == 'yolov3-tiny.conv.15':
|
||||
|
@ -417,3 +399,35 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
|
|||
|
||||
else:
|
||||
print('Error: extension not supported.')
|
||||
|
||||
|
||||
def attempt_download(weights):
|
||||
# Attempt to download pretrained weights if not found locally
|
||||
|
||||
msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI'
|
||||
if not os.path.isfile(weights):
|
||||
file = Path(weights).name
|
||||
|
||||
if file == 'yolov3-spp.weights':
|
||||
gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights)
|
||||
elif file == 'yolov3-spp.pt':
|
||||
gdrive_download(id='1vFlbJ_dXPvtwaLLOu-twnjK4exdFiQ73', name=weights)
|
||||
elif file == 'yolov3.pt':
|
||||
gdrive_download(id='11uy0ybbOXA2hc-NJkJbbbkDwNX1QZDlz', name=weights)
|
||||
elif file == 'yolov3-tiny.pt':
|
||||
gdrive_download(id='1qKSgejNeNczgNNiCn9ZF_o55GFk1DjY_', name=weights)
|
||||
elif file == 'darknet53.conv.74':
|
||||
gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights)
|
||||
elif file == 'yolov3-tiny.conv.15':
|
||||
gdrive_download(id='140PnSedCsGGgu3rOD6Ez4oI6cdDzerLC', name=weights)
|
||||
|
||||
else:
|
||||
try: # download from pjreddie.com
|
||||
url = 'https://pjreddie.com/media/files/' + file
|
||||
print('Downloading ' + url)
|
||||
os.system('curl -f ' + url + ' -o ' + weights)
|
||||
except IOError:
|
||||
print(msg)
|
||||
os.system('rm ' + weights) # remove partial downloads
|
||||
|
||||
assert os.path.exists(weights), msg # download missing weights from Google Drive
|
||||
|
|
1
test.py
1
test.py
|
@ -27,6 +27,7 @@ def test(cfg,
|
|||
model = Darknet(cfg, img_size).to(device)
|
||||
|
||||
# Load weights
|
||||
attempt_download(weights)
|
||||
if weights.endswith('.pt'): # pytorch format
|
||||
model.load_state_dict(torch.load(weights, map_location=device)['model'])
|
||||
else: # darknet format
|
||||
|
|
Loading…
Reference in New Issue