updates
This commit is contained in:
parent
39f63b7110
commit
deb200f6bf
|
@ -8,7 +8,7 @@ from utils.utils import *
|
|||
|
||||
|
||||
def detect(cfg,
|
||||
data_cfg,
|
||||
data,
|
||||
weights,
|
||||
images='data/samples', # input folder
|
||||
output='output', # output folder
|
||||
|
@ -59,7 +59,7 @@ def detect(cfg,
|
|||
dataloader = LoadImages(images, img_size=img_size)
|
||||
|
||||
# Get classes and colors
|
||||
classes = load_classes(parse_data_cfg(data_cfg)['names'])
|
||||
classes = load_classes(parse_data_cfg(data)['names'])
|
||||
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]
|
||||
|
||||
for i, (path, img, im0, vid_cap) in enumerate(dataloader):
|
||||
|
@ -133,7 +133,7 @@ if __name__ == '__main__':
|
|||
|
||||
with torch.no_grad():
|
||||
detect(opt.cfg,
|
||||
opt.data_cfg,
|
||||
opt.data,
|
||||
opt.weights,
|
||||
images=opt.images,
|
||||
img_size=opt.img_size,
|
||||
|
|
12
test.py
12
test.py
|
@ -9,7 +9,7 @@ from utils.utils import *
|
|||
|
||||
|
||||
def test(cfg,
|
||||
data_cfg,
|
||||
data,
|
||||
weights=None,
|
||||
batch_size=16,
|
||||
img_size=416,
|
||||
|
@ -37,10 +37,10 @@ def test(cfg,
|
|||
device = next(model.parameters()).device # get model device
|
||||
|
||||
# Configure run
|
||||
data_cfg = parse_data_cfg(data_cfg)
|
||||
nc = int(data_cfg['classes']) # number of classes
|
||||
test_path = data_cfg['valid'] # path to test images
|
||||
names = load_classes(data_cfg['names']) # class names
|
||||
data = parse_data_cfg(data)
|
||||
nc = int(data['classes']) # number of classes
|
||||
test_path = data['valid'] # path to test images
|
||||
names = load_classes(data['names']) # class names
|
||||
|
||||
# Dataloader
|
||||
dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
|
||||
|
@ -203,7 +203,7 @@ if __name__ == '__main__':
|
|||
|
||||
with torch.no_grad():
|
||||
mAP = test(opt.cfg,
|
||||
opt.data_cfg,
|
||||
opt.data,
|
||||
opt.weights,
|
||||
opt.batch_size,
|
||||
opt.img_size,
|
||||
|
|
10
train.py
10
train.py
|
@ -62,7 +62,7 @@ hyp = {'giou': 1.153, # giou loss gain
|
|||
|
||||
|
||||
def train(cfg,
|
||||
data_cfg,
|
||||
data,
|
||||
img_size=416,
|
||||
epochs=100, # 500200 batches at bs 16, 117263 images = 273 epochs
|
||||
batch_size=16,
|
||||
|
@ -81,7 +81,7 @@ def train(cfg,
|
|||
img_size = img_sz_max * 32 # initiate with maximum multi_scale size
|
||||
|
||||
# Configure run
|
||||
data_dict = parse_data_cfg(data_cfg)
|
||||
data_dict = parse_data_cfg(data)
|
||||
train_path = data_dict['train']
|
||||
nc = int(data_dict['classes']) # number of classes
|
||||
|
||||
|
@ -270,7 +270,7 @@ def train(cfg,
|
|||
# Calculate mAP (always test final epoch, skip first 5 if opt.nosave)
|
||||
if not (opt.notest or (opt.nosave and epoch < 10)) or epoch == epochs - 1:
|
||||
with torch.no_grad():
|
||||
results, maps = test.test(cfg, data_cfg, batch_size=batch_size, img_size=opt.img_size, model=model,
|
||||
results, maps = test.test(cfg, data, batch_size=batch_size, img_size=opt.img_size, model=model,
|
||||
conf_thres=0.1)
|
||||
|
||||
# Write epoch results
|
||||
|
@ -361,7 +361,7 @@ if __name__ == '__main__':
|
|||
|
||||
# Train
|
||||
results = train(opt.cfg,
|
||||
opt.data_cfg,
|
||||
opt.data,
|
||||
img_size=opt.img_size,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
|
@ -393,7 +393,7 @@ if __name__ == '__main__':
|
|||
|
||||
# Train mutation
|
||||
results = train(opt.cfg,
|
||||
opt.data_cfg,
|
||||
opt.data,
|
||||
img_size=opt.img_size,
|
||||
epochs=opt.epochs,
|
||||
batch_size=opt.batch_size,
|
||||
|
|
Loading…
Reference in New Issue