This commit is contained in:
Glenn Jocher 2019-07-20 15:10:31 +02:00
parent 39f63b7110
commit deb200f6bf
3 changed files with 14 additions and 14 deletions

View File

@ -8,7 +8,7 @@ from utils.utils import *
def detect(cfg,
data_cfg,
data,
weights,
images='data/samples', # input folder
output='output', # output folder
@ -59,7 +59,7 @@ def detect(cfg,
dataloader = LoadImages(images, img_size=img_size)
# Get classes and colors
classes = load_classes(parse_data_cfg(data_cfg)['names'])
classes = load_classes(parse_data_cfg(data)['names'])
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
for i, (path, img, im0, vid_cap) in enumerate(dataloader):
@ -133,7 +133,7 @@ if __name__ == '__main__':
with torch.no_grad():
detect(opt.cfg,
opt.data_cfg,
opt.data,
opt.weights,
images=opt.images,
img_size=opt.img_size,

12
test.py
View File

@ -9,7 +9,7 @@ from utils.utils import *
def test(cfg,
data_cfg,
data,
weights=None,
batch_size=16,
img_size=416,
@ -37,10 +37,10 @@ def test(cfg,
device = next(model.parameters()).device # get model device
# Configure run
data_cfg = parse_data_cfg(data_cfg)
nc = int(data_cfg['classes']) # number of classes
test_path = data_cfg['valid'] # path to test images
names = load_classes(data_cfg['names']) # class names
data = parse_data_cfg(data)
nc = int(data['classes']) # number of classes
test_path = data['valid'] # path to test images
names = load_classes(data['names']) # class names
# Dataloader
dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
@ -203,7 +203,7 @@ if __name__ == '__main__':
with torch.no_grad():
mAP = test(opt.cfg,
opt.data_cfg,
opt.data,
opt.weights,
opt.batch_size,
opt.img_size,

View File

@ -62,7 +62,7 @@ hyp = {'giou': 1.153, # giou loss gain
def train(cfg,
data_cfg,
data,
img_size=416,
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
batch_size=16,
@ -81,7 +81,7 @@ def train(cfg,
img_size = img_sz_max * 32 # initiate with maximum multi_scale size
# Configure run
data_dict = parse_data_cfg(data_cfg)
data_dict = parse_data_cfg(data)
train_path = data_dict['train']
nc = int(data_dict['classes']) # number of classes
@ -270,7 +270,7 @@ def train(cfg,
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
with torch.no_grad():
results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=opt.img_size, model=model,
results, maps = test.test(cfg, data, batch_size=batch_size, img_size=opt.img_size, model=model,
conf_thres=0.1)
# Write epoch results
@ -361,7 +361,7 @@ if __name__ == '__main__':
# Train
results = train(opt.cfg,
opt.data_cfg,
opt.data,
img_size=opt.img_size,
epochs=opt.epochs,
batch_size=opt.batch_size,
@ -393,7 +393,7 @@ if __name__ == '__main__':
# Train mutation
results = train(opt.cfg,
opt.data_cfg,
opt.data,
img_size=opt.img_size,
epochs=opt.epochs,
batch_size=opt.batch_size,