From f20a03e28ec72714893eceb04a06c9a04ffdbc68 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 10 Sep 2019 14:59:45 +0200 Subject: [PATCH] updates --- train.py | 3 ++- utils/torch_utils.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 9d6205f8..6379ddb9 100644 --- a/train.py +++ b/train.py @@ -381,11 +381,12 @@ if __name__ == '__main__': parser.add_argument('--arc', type=str, default='defaultpw', help='yolo architecture') # defaultpw, uCE, uBCE parser.add_argument('--prebias', action='store_true', help='transfer-learn yolo biases prior to training') parser.add_argument('--name', default='', help='renames results.txt to results_name.txt if supplied') + parser.add_argument('--device', default='', help='select device if multi-gpu, i.e. 0 or 0,1') parser.add_argument('--var', type=float, help='debug variable') opt = parser.parse_args() opt.weights = 'weights/last.pt' if opt.resume else opt.weights print(opt) - device = torch_utils.select_device(apex=mixed_precision) + device = torch_utils.select_device(opt.device, apex=mixed_precision) tb_writer = None if not opt.evolve: # Train normally diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 9bf97548..4805c23b 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -1,3 +1,4 @@ +import os import torch @@ -13,7 +14,11 @@ def init_seeds(seed=0): torch.backends.cudnn.benchmark = False -def select_device(force_cpu=False, apex=False): +def select_device(device=None, force_cpu=False, apex=False): + # Set environment variable if device is specified + if device: + os.environ['CUDA_VISIBLE_DEVICES'] = device + # apex if mixed precision training https://github.com/NVIDIA/apex cuda = False if force_cpu else torch.cuda.is_available() device = torch.device('cuda:0' if cuda else 'cpu')