updates
This commit is contained in:
parent
39f63b7110
commit
deb200f6bf
|
@ -8,7 +8,7 @@ from utils.utils import *
|
||||||
|
|
||||||
|
|
||||||
def detect(cfg,
|
def detect(cfg,
|
||||||
data_cfg,
|
data,
|
||||||
weights,
|
weights,
|
||||||
images='data/samples', # input folder
|
images='data/samples', # input folder
|
||||||
output='output', # output folder
|
output='output', # output folder
|
||||||
|
@ -59,7 +59,7 @@ def detect(cfg,
|
||||||
dataloader = LoadImages(images, img_size=img_size)
|
dataloader = LoadImages(images, img_size=img_size)
|
||||||
|
|
||||||
# Get classes and colors
|
# Get classes and colors
|
||||||
classes = load_classes(parse_data_cfg(data_cfg)['names'])
|
classes = load_classes(parse_data_cfg(data)['names'])
|
||||||
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
|
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
|
||||||
|
|
||||||
for i, (path, img, im0, vid_cap) in enumerate(dataloader):
|
for i, (path, img, im0, vid_cap) in enumerate(dataloader):
|
||||||
|
@ -133,7 +133,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
detect(opt.cfg,
|
detect(opt.cfg,
|
||||||
opt.data_cfg,
|
opt.data,
|
||||||
opt.weights,
|
opt.weights,
|
||||||
images=opt.images,
|
images=opt.images,
|
||||||
img_size=opt.img_size,
|
img_size=opt.img_size,
|
||||||
|
|
12
test.py
12
test.py
|
@ -9,7 +9,7 @@ from utils.utils import *
|
||||||
|
|
||||||
|
|
||||||
def test(cfg,
|
def test(cfg,
|
||||||
data_cfg,
|
data,
|
||||||
weights=None,
|
weights=None,
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
img_size=416,
|
img_size=416,
|
||||||
|
@ -37,10 +37,10 @@ def test(cfg,
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
data_cfg = parse_data_cfg(data_cfg)
|
data = parse_data_cfg(data)
|
||||||
nc = int(data_cfg['classes']) # number of classes
|
nc = int(data['classes']) # number of classes
|
||||||
test_path = data_cfg['valid'] # path to test images
|
test_path = data['valid'] # path to test images
|
||||||
names = load_classes(data_cfg['names']) # class names
|
names = load_classes(data['names']) # class names
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
|
dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
|
||||||
|
@ -203,7 +203,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
mAP = test(opt.cfg,
|
mAP = test(opt.cfg,
|
||||||
opt.data_cfg,
|
opt.data,
|
||||||
opt.weights,
|
opt.weights,
|
||||||
opt.batch_size,
|
opt.batch_size,
|
||||||
opt.img_size,
|
opt.img_size,
|
||||||
|
|
10
train.py
10
train.py
|
@ -62,7 +62,7 @@ hyp = {'giou': 1.153, # giou loss gain
|
||||||
|
|
||||||
|
|
||||||
def train(cfg,
|
def train(cfg,
|
||||||
data_cfg,
|
data,
|
||||||
img_size=416,
|
img_size=416,
|
||||||
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
||||||
batch_size=16,
|
batch_size=16,
|
||||||
|
@ -81,7 +81,7 @@ def train(cfg,
|
||||||
img_size = img_sz_max * 32 # initiate with maximum multi_scale size
|
img_size = img_sz_max * 32 # initiate with maximum multi_scale size
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
data_dict = parse_data_cfg(data_cfg)
|
data_dict = parse_data_cfg(data)
|
||||||
train_path = data_dict['train']
|
train_path = data_dict['train']
|
||||||
nc = int(data_dict['classes']) # number of classes
|
nc = int(data_dict['classes']) # number of classes
|
||||||
|
|
||||||
|
@ -270,7 +270,7 @@ def train(cfg,
|
||||||
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
|
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
|
||||||
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
|
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=opt.img_size, model=model,
|
results, maps = test.test(cfg, data, batch_size=batch_size, img_size=opt.img_size, model=model,
|
||||||
conf_thres=0.1)
|
conf_thres=0.1)
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
|
@ -361,7 +361,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
results = train(opt.cfg,
|
results = train(opt.cfg,
|
||||||
opt.data_cfg,
|
opt.data,
|
||||||
img_size=opt.img_size,
|
img_size=opt.img_size,
|
||||||
epochs=opt.epochs,
|
epochs=opt.epochs,
|
||||||
batch_size=opt.batch_size,
|
batch_size=opt.batch_size,
|
||||||
|
@ -393,7 +393,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# Train mutation
|
# Train mutation
|
||||||
results = train(opt.cfg,
|
results = train(opt.cfg,
|
||||||
opt.data_cfg,
|
opt.data,
|
||||||
img_size=opt.img_size,
|
img_size=opt.img_size,
|
||||||
epochs=opt.epochs,
|
epochs=opt.epochs,
|
||||||
batch_size=opt.batch_size,
|
batch_size=opt.batch_size,
|
||||||
|
|
Loading…
Reference in New Issue