From 5a566454f5efc4ce424c84826358b5de602d7aad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Garc=C3=ADa?= Date: Wed, 5 Dec 2018 11:55:27 +0100 Subject: [PATCH] Extract seed and cuda initialization utils --- detect.py | 6 ++++-- test.py | 11 ++++++++--- train.py | 20 ++++++++++---------- utils/torch_utils.py | 23 +++++++++++++++++++++++ utils/utils.py | 8 ++++++++ 5 files changed, 53 insertions(+), 15 deletions(-) create mode 100644 utils/torch_utils.py diff --git a/detect.py b/detect.py index 7953f30e..cc2ef727 100755 --- a/detect.py +++ b/detect.py @@ -5,8 +5,6 @@ from models import * from utils.datasets import * from utils.utils import * -cuda = torch.cuda.is_available() -device = torch.device('cuda:0' if cuda else 'cpu') f_path = os.path.dirname(os.path.realpath(__file__)) + '/' parser = argparse.ArgumentParser() @@ -28,6 +26,10 @@ print(opt) def main(opt): + + device = torch_utils.select_device() + print("Using device: \"{}\"".format(device)) + os.system('rm -rf ' + opt.output_folder) os.makedirs(opt.output_folder, exist_ok=True) diff --git a/test.py b/test.py index f65d373c..c0cf476b 100644 --- a/test.py +++ b/test.py @@ -4,6 +4,8 @@ from models import * from utils.datasets import * from utils.utils import * +from utils import torch_utils + parser = argparse.ArgumentParser(prog='test.py') parser.add_argument('-batch_size', type=int, default=32, help='size of each image batch') parser.add_argument('-cfg', type=str, default='cfg/yolov3.cfg', help='path to model config file') @@ -18,11 +20,11 @@ parser.add_argument('-img_size', type=int, default=416, help='size of each image opt = parser.parse_args() print(opt, end='\n\n') -cuda = torch.cuda.is_available() -device = torch.device('cuda:0' if cuda else 'cpu') - def main(opt): + device = torch_utils.select_device() + print("Using device: \"{}\"".format(device)) + # Configure run data_config = parse_data_config(opt.data_config_path) nC = int(data_config['classes']) # number of classes (80 for COCO) @@ -128,4 +130,7 @@ def main(opt): if __name__ == '__main__': + + init_seeds() + mAP = main(opt) diff --git a/train.py b/train.py index 93ac9c96..103d6f03 100644 --- a/train.py +++ b/train.py @@ -6,6 +6,8 @@ from models import * from utils.datasets import * from utils.utils import * +from utils import torch_utils + parser = argparse.ArgumentParser() parser.add_argument('-epochs', type=int, default=100, help='number of epochs') parser.add_argument('-batch_size', type=int, default=16, help='size of each image batch') @@ -26,20 +28,15 @@ print(opt) sys.argv[1:] = [] # delete any train.py command-line arguments before they reach test.py import test # must follow sys.argv[1:] = [] -cuda = torch.cuda.is_available() -device = torch.device('cuda:0' if cuda else 'cpu') -random.seed(0) -np.random.seed(0) -torch.manual_seed(0) -if cuda: - torch.cuda.manual_seed(0) - torch.cuda.manual_seed_all(0) +def main(opt): + + device = torch_utils.select_device() + print("Using device: \"{}\"".format(device)) + if not opt.multi_scale: torch.backends.cudnn.benchmark = True - -def main(opt): os.makedirs('weights', exist_ok=True) # Configure run @@ -217,5 +214,8 @@ def main(opt): if __name__ == '__main__': + + init_seeds() + torch.cuda.empty_cache() main(opt) diff --git a/utils/torch_utils.py b/utils/torch_utils.py new file mode 100644 index 00000000..58bf5ff4 --- /dev/null +++ b/utils/torch_utils.py @@ -0,0 +1,23 @@ +import torch + + +def check_cuda(): + return torch.cuda.is_available() + + +CUDA_AVAILABLE = check_cuda() + + +def init_seeds(seed=0): + torch.manual_seed(seed) + if CUDA_AVAILABLE: + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def select_device(force_cpu=False): + if force_cpu: + device = torch.device('cpu') + else: + device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu') + return device diff --git a/utils/utils.py b/utils/utils.py index 6fcf5fac..12d161bd 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -5,11 +5,19 @@ import numpy as np import torch import torch.nn.functional as F +from utils import torch_utils + # Set printoptions torch.set_printoptions(linewidth=1320, precision=5, profile='long') np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 +def init_seeds(seed=0): + random.seed(seed) + np.random.seed(seed) + torch_utils.init_seeds(seed=seed) + + def load_classes(path): """ Loads class labels at 'path'