efficient calling test dataloader during training (#688)
* efficient calling test dataloader efficient calling test dataloader * efficient calling test dataloader during training efficient calling test dataloader during training * Update test.py * Update train.py * Update train.py
This commit is contained in:
parent
e27b124828
commit
7f6bb9a39f
21
test.py
21
test.py
|
@ -17,7 +17,9 @@ def test(cfg,
|
||||||
conf_thres=0.001,
|
conf_thres=0.001,
|
||||||
nms_thres=0.5,
|
nms_thres=0.5,
|
||||||
save_json=False,
|
save_json=False,
|
||||||
model=None):
|
model=None,
|
||||||
|
names=None,
|
||||||
|
dataloader=None):
|
||||||
# Initialize/load model and set device
|
# Initialize/load model and set device
|
||||||
if model is None:
|
if model is None:
|
||||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||||
|
@ -40,15 +42,16 @@ def test(cfg,
|
||||||
verbose = False
|
verbose = False
|
||||||
|
|
||||||
# Configure run
|
# Configure run
|
||||||
data = parse_data_cfg(data)
|
if (dataloader and names) is None:
|
||||||
nc = int(data['classes']) # number of classes
|
data = parse_data_cfg(data)
|
||||||
test_path = data['valid'] # path to test images
|
nc = int(data['classes']) # number of classes
|
||||||
names = load_classes(data['names']) # class names
|
test_path = data['valid'] # path to test images
|
||||||
|
names = load_classes(data['names']) # class names
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
|
dataset = LoadImagesAndLabels(test_path, img_size, batch_size)
|
||||||
batch_size = min(batch_size, len(dataset))
|
batch_size = min(batch_size, len(dataset))
|
||||||
dataloader = DataLoader(dataset,
|
dataloader = DataLoader(dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
|
num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
|
|
21
train.py
21
train.py
|
@ -73,7 +73,10 @@ def train():
|
||||||
data_dict = parse_data_cfg(data)
|
data_dict = parse_data_cfg(data)
|
||||||
train_path = data_dict['train']
|
train_path = data_dict['train']
|
||||||
nc = int(data_dict['classes']) # number of classes
|
nc = int(data_dict['classes']) # number of classes
|
||||||
|
names = load_classes(data_dict['names'])
|
||||||
|
|
||||||
|
test_path = data_dict['valid'] # path to test images
|
||||||
|
|
||||||
# Remove previous results
|
# Remove previous results
|
||||||
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
||||||
os.remove(f)
|
os.remove(f)
|
||||||
|
@ -196,7 +199,9 @@ def train():
|
||||||
image_weights=opt.img_weights,
|
image_weights=opt.img_weights,
|
||||||
cache_labels=True if epochs > 10 else False,
|
cache_labels=True if epochs > 10 else False,
|
||||||
cache_images=False if opt.prebias else opt.cache_images)
|
cache_images=False if opt.prebias else opt.cache_images)
|
||||||
|
|
||||||
|
dataset_test = LoadImagesAndLabels(test_path, img_size, batch_size)
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
batch_size = min(batch_size, len(dataset))
|
batch_size = min(batch_size, len(dataset))
|
||||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers
|
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers
|
||||||
|
@ -207,6 +212,12 @@ def train():
|
||||||
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=dataset.collate_fn)
|
collate_fn=dataset.collate_fn)
|
||||||
|
dataloader_test = DataLoader(dataset_test,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=dataloader_test.collate_fn)
|
||||||
|
|
||||||
|
|
||||||
# Start training
|
# Start training
|
||||||
model.nc = nc # attach number of classes to model
|
model.nc = nc # attach number of classes to model
|
||||||
|
@ -316,12 +327,14 @@ def train():
|
||||||
if not (opt.notest or (opt.nosave and epoch < 10)) or final_epoch:
|
if not (opt.notest or (opt.nosave and epoch < 10)) or final_epoch:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
results, maps = test.test(cfg,
|
results, maps = test.test(cfg,
|
||||||
data,
|
data = None,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
img_size=opt.img_size,
|
img_size=opt.img_size,
|
||||||
model=model,
|
model=model,
|
||||||
conf_thres=0.001 if final_epoch and epoch > 0 else 0.1, # 0.1 for speed
|
conf_thres=0.001 if final_epoch and epoch > 0 else 0.1, # 0.1 for speed
|
||||||
save_json=final_epoch and epoch > 0 and 'coco.data' in data)
|
save_json=final_epoch and epoch > 0 and 'coco.data' in data,
|
||||||
|
names = names,
|
||||||
|
dataloader = dataloader_test)
|
||||||
|
|
||||||
# Write epoch results
|
# Write epoch results
|
||||||
with open(results_file, 'a') as f:
|
with open(results_file, 'a') as f:
|
||||||
|
|
Loading…
Reference in New Issue