efficient calling test dataloader during training (#688)
* efficient calling test dataloader efficient calling test dataloader * efficient calling test dataloader during training efficient calling test dataloader during training * Update test.py * Update train.py * Update train.py
This commit is contained in:
parent
e27b124828
commit
7f6bb9a39f
5
test.py
5
test.py
|
@ -17,7 +17,9 @@ def test(cfg,
|
|||
conf_thres=0.001,
|
||||
nms_thres=0.5,
|
||||
save_json=False,
|
||||
model=None):
|
||||
model=None,
|
||||
names=None,
|
||||
dataloader=None):
|
||||
# Initialize/load model and set device
|
||||
if model is None:
|
||||
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
||||
|
@ -40,6 +42,7 @@ def test(cfg,
|
|||
verbose = False
|
||||
|
||||
# Configure run
|
||||
if (dataloader and names) is None:
|
||||
data = parse_data_cfg(data)
|
||||
nc = int(data['classes']) # number of classes
|
||||
test_path = data['valid'] # path to test images
|
||||
|
|
17
train.py
17
train.py
|
@ -73,6 +73,9 @@ def train():
|
|||
data_dict = parse_data_cfg(data)
|
||||
train_path = data_dict['train']
|
||||
nc = int(data_dict['classes']) # number of classes
|
||||
names = load_classes(data_dict['names'])
|
||||
|
||||
test_path = data_dict['valid'] # path to test images
|
||||
|
||||
# Remove previous results
|
||||
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
|
||||
|
@ -197,6 +200,8 @@ def train():
|
|||
cache_labels=True if epochs > 10 else False,
|
||||
cache_images=False if opt.prebias else opt.cache_images)
|
||||
|
||||
dataset_test = LoadImagesAndLabels(test_path, img_size, batch_size)
|
||||
|
||||
# Dataloader
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers
|
||||
|
@ -207,6 +212,12 @@ def train():
|
|||
shuffle=not opt.rect, # Shuffle=True unless rectangular training is used
|
||||
pin_memory=True,
|
||||
collate_fn=dataset.collate_fn)
|
||||
dataloader_test = DataLoader(dataset_test,
|
||||
batch_size=batch_size,
|
||||
num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]),
|
||||
pin_memory=True,
|
||||
collate_fn=dataloader_test.collate_fn)
|
||||
|
||||
|
||||
# Start training
|
||||
model.nc = nc # attach number of classes to model
|
||||
|
@ -316,12 +327,14 @@ def train():
|
|||
if not (opt.notest or (opt.nosave and epoch < 10)) or final_epoch:
|
||||
with torch.no_grad():
|
||||
results, maps = test.test(cfg,
|
||||
data,
|
||||
data = None,
|
||||
batch_size=batch_size,
|
||||
img_size=opt.img_size,
|
||||
model=model,
|
||||
conf_thres=0.001 if final_epoch and epoch > 0 else 0.1, # 0.1 for speed
|
||||
save_json=final_epoch and epoch > 0 and 'coco.data' in data)
|
||||
save_json=final_epoch and epoch > 0 and 'coco.data' in data,
|
||||
names = names,
|
||||
dataloader = dataloader_test)
|
||||
|
||||
# Write epoch results
|
||||
with open(results_file, 'a') as f:
|
||||
|
|
Loading…
Reference in New Issue