This commit is contained in:
Glenn Jocher 2019-09-26 12:52:16 +02:00
parent 2487b0694f
commit 163025649a
1 changed files with 3 additions and 2 deletions

View File

@ -14,10 +14,11 @@ def init_seeds(seed=0):
torch.backends.cudnn.benchmark = False
def select_device(device=None, apex=False):
if device == 'cpu':
def select_device(device='', apex=False):
if device.lower() == 'cpu':
pass
elif device: # Set environment variable if device is specified
assert torch.cuda.is_available(), 'CUDA unavailable, invalid device %s requested' % device
os.environ['CUDA_VISIBLE_DEVICES'] = device
# apex if mixed precision training https://github.com/NVIDIA/apex