This commit is contained in:
Glenn Jocher 2019-08-15 18:15:27 +02:00
parent be7f4fa72f
commit 5a9fb2411d
3 changed files with 25 additions and 22 deletions

View File

@ -1,8 +1,7 @@
import torch.nn.functional as F
from utils.parse_config import * from utils.parse_config import *
from utils.utils import * from utils.utils import *
from pathlib import Path
import torch.nn.functional as F
ONNX_EXPORT = False ONNX_EXPORT = False
@ -71,9 +70,7 @@ def create_modules(module_defs, img_size):
elif mdef['type'] == 'yolo': elif mdef['type'] == 'yolo':
yolo_index += 1 yolo_index += 1
mask = [int(x) for x in mdef['mask'].split(',')] # anchor mask mask = [int(x) for x in mdef['mask'].split(',')] # anchor mask
a = [float(x) for x in mdef['anchors'].split(',')] # anchors modules = YOLOLayer(anchors=mdef['anchors'][mask], # anchor list
a = [(a[i], a[i + 1]) for i in range(0, len(a), 2)]
modules = YOLOLayer(anchors=[a[i] for i in mask], # anchor list
nc=int(mdef['classes']), # number of classes nc=int(mdef['classes']), # number of classes
img_size=img_size, # (416, 416) img_size=img_size, # (416, 416)
yolo_index=yolo_index) # 0, 1 or 2 yolo_index=yolo_index) # 0, 1 or 2

View File

@ -1,35 +1,42 @@
import numpy as np
def parse_model_cfg(path): def parse_model_cfg(path):
"""Parses the yolo-v3 layer configuration file and returns module definitions""" # Parses the yolo-v3 layer configuration file and returns module definitions
file = open(path, 'r') file = open(path, 'r')
lines = file.read().split('\n') lines = file.read().split('\n')
lines = [x for x in lines if x and not x.startswith('#')] lines = [x for x in lines if x and not x.startswith('#')]
lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces
module_defs = [] mdefs = [] # module definitions
for line in lines: for line in lines:
if line.startswith('['): # This marks the start of a new block if line.startswith('['): # This marks the start of a new block
module_defs.append({}) mdefs.append({})
module_defs[-1]['type'] = line[1:-1].rstrip() mdefs[-1]['type'] = line[1:-1].rstrip()
if module_defs[-1]['type'] == 'convolutional': if mdefs[-1]['type'] == 'convolutional':
module_defs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later) mdefs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later)
else: else:
key, value = line.split("=") key, val = line.split("=")
value = value.strip() key = key.rstrip()
module_defs[-1][key.rstrip()] = value.strip()
return module_defs if 'anchors' in key:
mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2)) # np anchors
else:
mdefs[-1][key] = val.strip()
return mdefs
def parse_data_cfg(path): def parse_data_cfg(path):
"""Parses the data configuration file""" # Parses the data configuration file
options = dict() options = dict()
options['gpus'] = '0,1,2,3'
options['num_workers'] = '10'
with open(path, 'r') as fp: with open(path, 'r') as fp:
lines = fp.readlines() lines = fp.readlines()
for line in lines: for line in lines:
line = line.strip() line = line.strip()
if line == '' or line.startswith('#'): if line == '' or line.startswith('#'):
continue continue
key, value = line.split('=') key, val = line.split('=')
options[key.strip()] = value.strip() options[key.strip()] = val.strip()
return options return options

View File

@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image
from tqdm import tqdm from tqdm import tqdm
from . import torch_utils # , google_utils from . import torch_utils # , google_utils