updates
This commit is contained in:
parent
be7f4fa72f
commit
5a9fb2411d
|
@ -1,8 +1,7 @@
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from utils.parse_config import *
|
from utils.parse_config import *
|
||||||
from utils.utils import *
|
from utils.utils import *
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
ONNX_EXPORT = False
|
ONNX_EXPORT = False
|
||||||
|
|
||||||
|
@ -71,9 +70,7 @@ def create_modules(module_defs, img_size):
|
||||||
elif mdef['type'] == 'yolo':
|
elif mdef['type'] == 'yolo':
|
||||||
yolo_index += 1
|
yolo_index += 1
|
||||||
mask = [int(x) for x in mdef['mask'].split(',')] # anchor mask
|
mask = [int(x) for x in mdef['mask'].split(',')] # anchor mask
|
||||||
a = [float(x) for x in mdef['anchors'].split(',')] # anchors
|
modules = YOLOLayer(anchors=mdef['anchors'][mask], # anchor list
|
||||||
a = [(a[i], a[i + 1]) for i in range(0, len(a), 2)]
|
|
||||||
modules = YOLOLayer(anchors=[a[i] for i in mask], # anchor list
|
|
||||||
nc=int(mdef['classes']), # number of classes
|
nc=int(mdef['classes']), # number of classes
|
||||||
img_size=img_size, # (416, 416)
|
img_size=img_size, # (416, 416)
|
||||||
yolo_index=yolo_index) # 0, 1 or 2
|
yolo_index=yolo_index) # 0, 1 or 2
|
||||||
|
|
|
@ -1,35 +1,42 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def parse_model_cfg(path):
|
def parse_model_cfg(path):
|
||||||
"""Parses the yolo-v3 layer configuration file and returns module definitions"""
|
# Parses the yolo-v3 layer configuration file and returns module definitions
|
||||||
file = open(path, 'r')
|
file = open(path, 'r')
|
||||||
lines = file.read().split('\n')
|
lines = file.read().split('\n')
|
||||||
lines = [x for x in lines if x and not x.startswith('#')]
|
lines = [x for x in lines if x and not x.startswith('#')]
|
||||||
lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces
|
lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces
|
||||||
module_defs = []
|
mdefs = [] # module definitions
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line.startswith('['): # This marks the start of a new block
|
if line.startswith('['): # This marks the start of a new block
|
||||||
module_defs.append({})
|
mdefs.append({})
|
||||||
module_defs[-1]['type'] = line[1:-1].rstrip()
|
mdefs[-1]['type'] = line[1:-1].rstrip()
|
||||||
if module_defs[-1]['type'] == 'convolutional':
|
if mdefs[-1]['type'] == 'convolutional':
|
||||||
module_defs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later)
|
mdefs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later)
|
||||||
else:
|
else:
|
||||||
key, value = line.split("=")
|
key, val = line.split("=")
|
||||||
value = value.strip()
|
key = key.rstrip()
|
||||||
module_defs[-1][key.rstrip()] = value.strip()
|
|
||||||
|
|
||||||
return module_defs
|
if 'anchors' in key:
|
||||||
|
mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2)) # np anchors
|
||||||
|
else:
|
||||||
|
mdefs[-1][key] = val.strip()
|
||||||
|
|
||||||
|
return mdefs
|
||||||
|
|
||||||
|
|
||||||
def parse_data_cfg(path):
|
def parse_data_cfg(path):
|
||||||
"""Parses the data configuration file"""
|
# Parses the data configuration file
|
||||||
options = dict()
|
options = dict()
|
||||||
options['gpus'] = '0,1,2,3'
|
|
||||||
options['num_workers'] = '10'
|
|
||||||
with open(path, 'r') as fp:
|
with open(path, 'r') as fp:
|
||||||
lines = fp.readlines()
|
lines = fp.readlines()
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line == '' or line.startswith('#'):
|
if line == '' or line.startswith('#'):
|
||||||
continue
|
continue
|
||||||
key, value = line.split('=')
|
key, val = line.split('=')
|
||||||
options[key.strip()] = value.strip()
|
options[key.strip()] = val.strip()
|
||||||
|
|
||||||
return options
|
return options
|
||||||
|
|
|
@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from PIL import Image
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from . import torch_utils # , google_utils
|
from . import torch_utils # , google_utils
|
||||||
|
|
Loading…
Reference in New Issue