updates
This commit is contained in:
parent
870020ed15
commit
c24702941f
54
models.py
54
models.py
|
@ -291,27 +291,9 @@ def create_grids(self, img_size=416, ng=(13, 13), device='cpu', type=torch.float
|
||||||
|
|
||||||
def load_darknet_weights(self, weights, cutoff=-1):
|
def load_darknet_weights(self, weights, cutoff=-1):
|
||||||
# Parses and loads the weights stored in 'weights'
|
# Parses and loads the weights stored in 'weights'
|
||||||
# cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved)
|
|
||||||
|
# Establish cutoffs (load layers between 0 and cutoff. if cutoff = -1 all are loaded)
|
||||||
file = Path(weights).name
|
file = Path(weights).name
|
||||||
|
|
||||||
# Try to download weights if not available locally
|
|
||||||
msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI'
|
|
||||||
if not os.path.isfile(weights):
|
|
||||||
if file == 'yolov3-spp.weights':
|
|
||||||
gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights)
|
|
||||||
elif file == 'darknet53.conv.74':
|
|
||||||
gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights)
|
|
||||||
else:
|
|
||||||
try: # download from pjreddie.com
|
|
||||||
url = 'https://pjreddie.com/media/files/' + file
|
|
||||||
print('Downloading ' + url)
|
|
||||||
os.system('curl -f ' + url + ' -o ' + weights)
|
|
||||||
except IOError:
|
|
||||||
print(msg)
|
|
||||||
os.system('rm ' + weights) # remove partial downloads
|
|
||||||
assert os.path.exists(weights), msg # download missing weights from Google Drive
|
|
||||||
|
|
||||||
# Establish cutoffs
|
|
||||||
if file == 'darknet53.conv.74':
|
if file == 'darknet53.conv.74':
|
||||||
cutoff = 75
|
cutoff = 75
|
||||||
elif file == 'yolov3-tiny.conv.15':
|
elif file == 'yolov3-tiny.conv.15':
|
||||||
|
@ -417,3 +399,35 @@ def convert(cfg='cfg/yolov3-spp.cfg', weights='weights/yolov3-spp.weights'):
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print('Error: extension not supported.')
|
print('Error: extension not supported.')
|
||||||
|
|
||||||
|
|
||||||
|
def attempt_download(weights):
|
||||||
|
# Attempt to download pretrained weights if not found locally
|
||||||
|
|
||||||
|
msg = weights + ' missing, download from https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI'
|
||||||
|
if not os.path.isfile(weights):
|
||||||
|
file = Path(weights).name
|
||||||
|
|
||||||
|
if file == 'yolov3-spp.weights':
|
||||||
|
gdrive_download(id='1oPCHKsM2JpM-zgyepQciGli9X0MTsJCO', name=weights)
|
||||||
|
elif file == 'yolov3-spp.pt':
|
||||||
|
gdrive_download(id='1vFlbJ_dXPvtwaLLOu-twnjK4exdFiQ73', name=weights)
|
||||||
|
elif file == 'yolov3.pt':
|
||||||
|
gdrive_download(id='11uy0ybbOXA2hc-NJkJbbbkDwNX1QZDlz', name=weights)
|
||||||
|
elif file == 'yolov3-tiny.pt':
|
||||||
|
gdrive_download(id='1qKSgejNeNczgNNiCn9ZF_o55GFk1DjY_', name=weights)
|
||||||
|
elif file == 'darknet53.conv.74':
|
||||||
|
gdrive_download(id='18xqvs_uwAqfTXp-LJCYLYNHBOcrwbrp0', name=weights)
|
||||||
|
elif file == 'yolov3-tiny.conv.15':
|
||||||
|
gdrive_download(id='140PnSedCsGGgu3rOD6Ez4oI6cdDzerLC', name=weights)
|
||||||
|
|
||||||
|
else:
|
||||||
|
try: # download from pjreddie.com
|
||||||
|
url = 'https://pjreddie.com/media/files/' + file
|
||||||
|
print('Downloading ' + url)
|
||||||
|
os.system('curl -f ' + url + ' -o ' + weights)
|
||||||
|
except IOError:
|
||||||
|
print(msg)
|
||||||
|
os.system('rm ' + weights) # remove partial downloads
|
||||||
|
|
||||||
|
assert os.path.exists(weights), msg # download missing weights from Google Drive
|
||||||
|
|
1
test.py
1
test.py
|
@ -27,6 +27,7 @@ def test(cfg,
|
||||||
model = Darknet(cfg, img_size).to(device)
|
model = Darknet(cfg, img_size).to(device)
|
||||||
|
|
||||||
# Load weights
|
# Load weights
|
||||||
|
attempt_download(weights)
|
||||||
if weights.endswith('.pt'): # pytorch format
|
if weights.endswith('.pt'): # pytorch format
|
||||||
model.load_state_dict(torch.load(weights, map_location=device)['model'])
|
model.load_state_dict(torch.load(weights, map_location=device)['model'])
|
||||||
else: # darknet format
|
else: # darknet format
|
||||||
|
|
1
train.py
1
train.py
|
@ -100,6 +100,7 @@ def train():
|
||||||
cutoff = -1 # backbone reaches to cutoff layer
|
cutoff = -1 # backbone reaches to cutoff layer
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
best_fitness = 0.
|
best_fitness = 0.
|
||||||
|
attempt_download(weights)
|
||||||
if weights.endswith('.pt'): # pytorch format
|
if weights.endswith('.pt'): # pytorch format
|
||||||
# possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
|
# possible weights are 'last.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
|
||||||
if opt.bucket:
|
if opt.bucket:
|
||||||
|
|
Loading…
Reference in New Issue