This commit is contained in:
Glenn Jocher 2019-07-20 15:10:31 +02:00
parent 39f63b7110
commit deb200f6bf
3 changed files with 14 additions and 14 deletions

View File

@ -8,7 +8,7 @@ from utils.utils import *
def detect(cfg, def detect(cfg,
data_cfg, data,
weights, weights,
images='data/samples', # input folder images='data/samples', # input folder
output='output', # output folder output='output', # output folder
@ -59,7 +59,7 @@ def detect(cfg,
dataloader = LoadImages(images, img_size=img_size) dataloader = LoadImages(images, img_size=img_size)
# Get classes and colors # Get classes and colors
classes = load_classes(parse_data_cfg(data_cfg)['names']) classes = load_classes(parse_data_cfg(data)['names'])
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))] colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
for i, (path, img, im0, vid_cap) in enumerate(dataloader): for i, (path, img, im0, vid_cap) in enumerate(dataloader):
@ -133,7 +133,7 @@ if __name__ == '__main__':
with torch.no_grad(): with torch.no_grad():
detect(opt.cfg, detect(opt.cfg,
opt.data_cfg, opt.data,
opt.weights, opt.weights,
images=opt.images, images=opt.images,
img_size=opt.img_size, img_size=opt.img_size,

12
test.py
View File

@ -9,7 +9,7 @@ from utils.utils import *
def test(cfg, def test(cfg,
data_cfg, data,
weights=None, weights=None,
batch_size=16, batch_size=16,
img_size=416, img_size=416,
@ -37,10 +37,10 @@ def test(cfg,
device = next(model.parameters()).device # get model device device = next(model.parameters()).device # get model device
# Configure run # Configure run
data_cfg = parse_data_cfg(data_cfg) data = parse_data_cfg(data)
nc = int(data_cfg['classes']) # number of classes nc = int(data['classes']) # number of classes
test_path = data_cfg['valid'] # path to test images test_path = data['valid'] # path to test images
names = load_classes(data_cfg['names']) # class names names = load_classes(data['names']) # class names
# Dataloader # Dataloader
dataset = LoadImagesAndLabels(test_path, img_size, batch_size) dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
@ -203,7 +203,7 @@ if __name__ == '__main__':
with torch.no_grad(): with torch.no_grad():
mAP = test(opt.cfg, mAP = test(opt.cfg,
opt.data_cfg, opt.data,
opt.weights, opt.weights,
opt.batch_size, opt.batch_size,
opt.img_size, opt.img_size,

View File

@ -62,7 +62,7 @@ hyp = {'giou': 1.153, # giou loss gain
def train(cfg, def train(cfg,
data_cfg, data,
img_size=416, img_size=416,
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
batch_size=16, batch_size=16,
@ -81,7 +81,7 @@ def train(cfg,
img_size = img_sz_max * 32 # initiate with maximum multi_scale size img_size = img_sz_max * 32 # initiate with maximum multi_scale size
# Configure run # Configure run
data_dict = parse_data_cfg(data_cfg) data_dict = parse_data_cfg(data)
train_path = data_dict['train'] train_path = data_dict['train']
nc = int(data_dict['classes']) # number of classes nc = int(data_dict['classes']) # number of classes
@ -270,7 +270,7 @@ def train(cfg,
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave) # Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1: if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
with torch.no_grad(): with torch.no_grad():
results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=opt.img_size, model=model, results, maps = test.test(cfg, data, batch_size=batch_size, img_size=opt.img_size, model=model,
conf_thres=0.1) conf_thres=0.1)
# Write epoch results # Write epoch results
@ -361,7 +361,7 @@ if __name__ == '__main__':
# Train # Train
results = train(opt.cfg, results = train(opt.cfg,
opt.data_cfg, opt.data,
img_size=opt.img_size, img_size=opt.img_size,
epochs=opt.epochs, epochs=opt.epochs,
batch_size=opt.batch_size, batch_size=opt.batch_size,
@ -393,7 +393,7 @@ if __name__ == '__main__':
# Train mutation # Train mutation
results = train(opt.cfg, results = train(opt.cfg,
opt.data_cfg, opt.data,
img_size=opt.img_size, img_size=opt.img_size,
epochs=opt.epochs, epochs=opt.epochs,
batch_size=opt.batch_size, batch_size=opt.batch_size,