This commit is contained in:
Glenn Jocher 2019-09-19 18:05:04 +02:00
parent 870020ed15
commit c24702941f
3 changed files with 36 additions and 20 deletions

View File

@ -291,27 +291,9 @@ def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float
def load_darknet_weights(self, weights, cutoff=-1):
# Parses and loads the weights stored in 'weights'
# cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved)
# Establish cutoffs (load layers between 0 and cutoff. if cutoff = -1 all are loaded)
file = Path(weights).name
# Try to download weights if not available locally
msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI'
if not os.path.isfile(weights):
if file == 'yolov3-spp.weights':
gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights)
elif file == 'darknet53.conv.74':
gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights)
else:
try: # download from pjreddie.com
url = 'https://pjreddie.com/media/files/' + file
print('Downloading ' + url)
os.system('curl -f ' + url + ' -o ' + weights)
except IOError:
print(msg)
os.system('rm ' + weights) # remove partial downloads
assert os.path.exists(weights), msg # download missing weights from Google Drive
# Establish cutoffs
if file == 'darknet53.conv.74':
cutoff = 75
elif file == 'yolov3-tiny.conv.15':
@ -417,3 +399,35 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
else:
print('Error: extension not supported.')
def attempt_download(weights):
# Attempt to download pretrained weights if not found locally
msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI'
if not os.path.isfile(weights):
file = Path(weights).name
if file == 'yolov3-spp.weights':
gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights)
elif file == 'yolov3-spp.pt':
gdrive_download(id='1vFlbJ_dXPvtwaLLOu-twnjK4exdFiQ73', name=weights)
elif file == 'yolov3.pt':
gdrive_download(id='11uy0ybbOXA2hc-NJkJbbbkDwNX1QZDlz', name=weights)
elif file == 'yolov3-tiny.pt':
gdrive_download(id='1qKSgejNeNczgNNiCn9ZF_o55GFk1DjY_', name=weights)
elif file == 'darknet53.conv.74':
gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights)
elif file == 'yolov3-tiny.conv.15':
gdrive_download(id='140PnSedCsGGgu3rOD6Ez4oI6cdDzerLC', name=weights)
else:
try: # download from pjreddie.com
url = 'https://pjreddie.com/media/files/' + file
print('Downloading ' + url)
os.system('curl -f ' + url + ' -o ' + weights)
except IOError:
print(msg)
os.system('rm ' + weights) # remove partial downloads
assert os.path.exists(weights), msg # download missing weights from Google Drive

View File

@ -27,6 +27,7 @@ def test(cfg,
model = Darknet(cfg, img_size).to(device)
# Load weights
attempt_download(weights)
if weights.endswith('.pt'): # pytorch format
model.load_state_dict(torch.load(weights, map_location=device)['model'])
else: # darknet format

View File

@ -100,6 +100,7 @@ def train():
cutoff = -1 # backbone reaches to cutoff layer
start_epoch = 0
best_fitness = 0.
attempt_download(weights)
if weights.endswith('.pt'): # pytorch format
# possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
if opt.bucket: