diff --git a/models.py b/models.py index 0e9d752f..d6e31eaa 100755 --- a/models.py +++ b/models.py @@ -1,8 +1,7 @@ +import torch.nn.functional as F + from utils.parse_config import * from utils.utils import * -from pathlib import Path - -import torch.nn.functional as F ONNX_EXPORT = False @@ -71,9 +70,7 @@ def create_modules(module_defs, img_size): elif mdef['type'] == 'yolo': yolo_index += 1 mask = [int(x) for x in mdef['mask'].split(',')] # anchor mask - a = [float(x) for x in mdef['anchors'].split(',')] # anchors - a = [(a[i], a[i + 1]) for i in range(0, len(a), 2)] - modules = YOLOLayer(anchors=[a[i] for i in mask], # anchor list + modules = YOLOLayer(anchors=mdef['anchors'][mask], # anchor list nc=int(mdef['classes']), # number of classes img_size=img_size, # (416, 416) yolo_index=yolo_index) # 0, 1 or 2 diff --git a/utils/parse_config.py b/utils/parse_config.py index a25eca3f..23581a51 100644 --- a/utils/parse_config.py +++ b/utils/parse_config.py @@ -1,35 +1,42 @@ +import numpy as np + + def parse_model_cfg(path): - """Parses the yolo-v3 layer configuration file and returns module definitions""" + # Parses the yolo-v3 layer configuration file and returns module definitions file = open(path, 'r') lines = file.read().split('\n') lines = [x for x in lines if x and not x.startswith('#')] lines = [x.rstrip().lstrip() for x in lines] # get rid of fringe whitespaces - module_defs = [] + mdefs = [] # module definitions for line in lines: if line.startswith('['): # This marks the start of a new block - module_defs.append({}) - module_defs[-1]['type'] = line[1:-1].rstrip() - if module_defs[-1]['type'] == 'convolutional': - module_defs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later) + mdefs.append({}) + mdefs[-1]['type'] = line[1:-1].rstrip() + if mdefs[-1]['type'] == 'convolutional': + mdefs[-1]['batch_normalize'] = 0 # pre-populate with zeros (may be overwritten later) else: - key, value = line.split("=") - value = value.strip() - module_defs[-1][key.rstrip()] = value.strip() + key, val = line.split("=") + key = key.rstrip() - return module_defs + if 'anchors' in key: + mdefs[-1][key] = np.array([float(x) for x in val.split(',')]).reshape((-1, 2)) # np anchors + else: + mdefs[-1][key] = val.strip() + + return mdefs def parse_data_cfg(path): - """Parses the data configuration file""" + # Parses the data configuration file options = dict() - options['gpus'] = '0,1,2,3' - options['num_workers'] = '10' with open(path, 'r') as fp: lines = fp.readlines() + for line in lines: line = line.strip() if line == '' or line.startswith('#'): continue - key, value = line.split('=') - options[key.strip()] = value.strip() + key, val = line.split('=') + options[key.strip()] = val.strip() + return options diff --git a/utils/utils.py b/utils/utils.py index 988a5ff1..b6e8a678 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -10,7 +10,6 @@ import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn -from PIL import Image from tqdm import tqdm from . import torch_utils # , google_utils