updates
This commit is contained in:
parent
4af819449c
commit
272b9c7c11
|
@ -1,5 +1,6 @@
|
|||
from utils.parse_config import *
|
||||
from utils.utils import *
|
||||
from pathlib import Path
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
@ -261,21 +262,21 @@ def create_grids(self, img_size=416, ng=(13, 13), device='cpu'):
|
|||
def load_darknet_weights(self, weights, cutoff=-1):
|
||||
# Parses and loads the weights stored in 'weights'
|
||||
# cutoff: save layers between 0 and cutoff (if cutoff = -1 all are saved)
|
||||
weights_file = weights.split(os.sep)[-1]
|
||||
file = Path(weights).name
|
||||
|
||||
# Try to download weights if not available locally
|
||||
if not os.path.isfile(weights):
|
||||
try:
|
||||
url = 'https://pjreddie.com/media/files/' + weights_file
|
||||
url = 'https://pjreddie.com/media/files/' + file
|
||||
print('Downloading ' + url)
|
||||
os.system('curl ' + url + ' -o ' + weights)
|
||||
except IOError:
|
||||
print(weights + ' not found.\nTry https://drive.google.com/drive/folders/1uxgUBemJVw9wZsdpboYbzUN4bcRhsuAI')
|
||||
|
||||
# Establish cutoffs
|
||||
if weights_file == 'darknet53.conv.74':
|
||||
if file == 'darknet53.conv.74':
|
||||
cutoff = 75
|
||||
elif weights_file == 'yolov3-tiny.conv.15':
|
||||
elif file == 'yolov3-tiny.conv.15':
|
||||
cutoff = 15
|
||||
|
||||
# Read weights file
|
||||
|
|
Loading…
Reference in New Issue