diff --git a/test.py b/test.py index 92c6c948..d7a9fbf5 100644 --- a/test.py +++ b/test.py @@ -17,7 +17,9 @@ def test(cfg, conf_thres=0.001, nms_thres=0.5, save_json=False, - model=None): + model=None, + names=None, + dataloader=None): # Initialize/load model and set device if model is None: device = torch_utils.select_device(opt.device, batch_size=batch_size) @@ -40,15 +42,16 @@ def test(cfg, verbose = False # Configure run - data = parse_data_cfg(data) - nc = int(data['classes']) # number of classes - test_path = data['valid'] # path to test images - names = load_classes(data['names']) # class names + if (dataloader and names) is None: + data = parse_data_cfg(data) + nc = int(data['classes']) # number of classes + test_path = data['valid'] # path to test images + names = load_classes(data['names']) # class names - # Dataloader - dataset = LoadImagesAndLabels(test_path, img_size, batch_size) - batch_size = min(batch_size, len(dataset)) - dataloader = DataLoader(dataset, + # Dataloader + dataset = LoadImagesAndLabels(test_path, img_size, batch_size) + batch_size = min(batch_size, len(dataset)) + dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]), pin_memory=True, diff --git a/train.py b/train.py index 9aa6f683..0ec33b86 100644 --- a/train.py +++ b/train.py @@ -73,7 +73,10 @@ def train(): data_dict = parse_data_cfg(data) train_path = data_dict['train'] nc = int(data_dict['classes']) # number of classes - + names = load_classes(data_dict['names']) + + test_path = data_dict['valid'] # path to test images + # Remove previous results for f in glob.glob('*_batch*.jpg') + glob.glob(results_file): os.remove(f) @@ -196,7 +199,9 @@ def train(): image_weights=opt.img_weights, cache_labels=True if epochs > 10 else False, cache_images=False if opt.prebias else opt.cache_images) - + + dataset_test = LoadImagesAndLabels(test_path, img_size, batch_size) + # Dataloader batch_size = min(batch_size, len(dataset)) nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]) # number of workers @@ -207,6 +212,12 @@ def train(): shuffle=not opt.rect, # Shuffle=True unless rectangular training is used pin_memory=True, collate_fn=dataset.collate_fn) + dataloader_test = DataLoader(dataset_test, + batch_size=batch_size, + num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 16]), + pin_memory=True, + collate_fn=dataloader_test.collate_fn) + # Start training model.nc = nc # attach number of classes to model @@ -316,12 +327,14 @@ def train(): if not (opt.notest or (opt.nosave and epoch < 10)) or final_epoch: with torch.no_grad(): results, maps = test.test(cfg, - data, + data = None, batch_size=batch_size, img_size=opt.img_size, model=model, conf_thres=0.001 if final_epoch and epoch > 0 else 0.1, # 0.1 for speed - save_json=final_epoch and epoch > 0 and 'coco.data' in data) + save_json=final_epoch and epoch > 0 and 'coco.data' in data, + names = names, + dataloader = dataloader_test) # Write epoch results with open(results_file, 'a') as f: