efficient calling test dataloader during training (#688)

* efficient calling test dataloader

efficient calling test dataloader

* efficient calling test dataloader during training

efficient calling test dataloader during training

* Update test.py

* Update train.py

* Update train.py
This commit is contained in:
Yonghye Kwon 2019-12-05 16:02:10 +09:00 committed by Glenn Jocher
parent e27b124828
commit 7f6bb9a39f
2 changed files with 29 additions and 13 deletions

21
test.py
View File

@@ -17,7 +17,9 @@ def test(cfg,
conf_thres=0.001, conf_thres=0.001,
nms_thres=0.5, nms_thres=0.5,
save_json=False, save_json=False,
model=None): model=None,
names=None,
dataloader=None):
# Initialize/load model and set device # Initialize/load model and set device
if model is None: if model is None:
device = torch_utils.select_device(opt.device, batch_size=batch_size) device = torch_utils.select_device(opt.device, batch_size=batch_size)
@@ -40,15 +42,16 @@ def test(cfg,
verbose = False verbose = False
# Configure run # Configure run
data = parse_data_cfg(data) if (dataloader and names) is None:
nc = int(data['classes']) # number of classes data = parse_data_cfg(data)
test_path = data['valid'] # path to test images nc = int(data['classes']) # number of classes
names = load_classes(data['names']) # class names test_path = data['valid'] # path to test images
names = load_classes(data['names']) # class names
# Dataloader # Dataloader
dataset = LoadImagesAndLabels(test_path, img_size, batch_size) dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=batch_size, batch_size=batch_size,
num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]), num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
pin_memory=True, pin_memory=True,

train.py
View File

@@ -73,7 +73,10 @@ def train():
data_dict = parse_data_cfg(data) data_dict = parse_data_cfg(data)
train_path = data_dict['train'] train_path = data_dict['train']
nc = int(data_dict['classes']) # number of classes nc = int(data_dict['classes']) # number of classes
names = load_classes(data_dict['names'])
test_path = data_dict['valid'] # path to test images
# Remove previous results # Remove previous results
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file): for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
os.remove(f) os.remove(f)
@@ -196,7 +199,9 @@ def train():
image_weights=opt.img_weights, image_weights=opt.img_weights,
cache_labels=True if epochs > 10 else False, cache_labels=True if epochs > 10 else False,
cache_images=False if opt.prebias else opt.cache_images) cache_images=False if opt.prebias else opt.cache_images)
dataset_test = LoadImagesAndLabels(test_path, img_size, batch_size)
# Dataloader # Dataloader
batch_size = min(batch_size, len(dataset)) batch_size = min(batch_size, len(dataset))
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers
@@ -207,6 +212,12 @@ def train():
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
pin_memory=True, pin_memory=True,
collate_fn=dataset.collate_fn) collate_fn=dataset.collate_fn)
dataloader_test = DataLoader(dataset_test,
batch_size=batch_size,
num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
pin_memory=True,
collate_fn=dataloader_test.collate_fn)
# Start training # Start training
model.nc = nc # attach number of classes to model model.nc = nc # attach number of classes to model
@@ -316,12 +327,14 @@ def train():
if not (opt.notest or (opt.nosave and epoch < 10)) or final_epoch: if not (opt.notest or (opt.nosave and epoch < 10)) or final_epoch:
with torch.no_grad(): with torch.no_grad():
results, maps = test.test(cfg, results, maps = test.test(cfg,
data, data = None,
batch_size=batch_size, batch_size=batch_size,
img_size=opt.img_size, img_size=opt.img_size,
model=model, model=model,
conf_thres=0.001 if final_epoch and epoch > 0 else 0.1, # 0.1 for speed conf_thres=0.001 if final_epoch and epoch > 0 else 0.1, # 0.1 for speed
save_json=final_epoch and epoch > 0 and 'coco.data' in data) save_json=final_epoch and epoch > 0 and 'coco.data' in data,
names = names,
dataloader = dataloader_test)
# Write epoch results # Write epoch results
with open(results_file, 'a') as f: with open(results_file, 'a') as f: