diff --git a/utils/torch_utils.py b/utils/torch_utils.py index aa42cb0f..ef87d6ca 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -14,10 +14,11 @@ def init_seeds(seed=0): torch.backends.cudnn.benchmark = False -def select_device(device=None, apex=False): - if device == 'cpu': +def select_device(device='', apex=False): + if device.lower() == 'cpu': pass elif device: # Set environment variable if device is specified + assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device os.environ['CUDA_VISIBLE_DEVICES'] = device # apex if mixed precision training https://github.com/NVIDIA/apex