From 163025649a3f406c2ebcf699c91183b1112a3ba8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 26 Sep 2019 12:52:16 +0200 Subject: [PATCH] updates --- utils/torch_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils/torch_utils.py b/utils/torch_utils.py index aa42cb0f..ef87d6ca 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -14,10 +14,11 @@ def init_seeds(seed=0): torch.backends.cudnn.benchmark = False -def select_device(device=None, apex=False): - if device == 'cpu': +def select_device(device='', apex=False): + if device.lower() == 'cpu': pass elif device: # Set environment variable if device is specified + assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device os.environ['CUDA_VISIBLE_DEVICES'] = device # apex if mixed precision training https://github.com/NVIDIA/apex