diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 6b4b2624..11b3504e 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -17,12 +17,12 @@ def init_seeds(seed=0): def select_device(device=None, apex=False): if device == 'cpu': - force_cpu = True + pass elif device: # Set environment variable if device is specified os.environ['CUDA_VISIBLE_DEVICES'] = device # apex if mixed precision training https://github.com/NVIDIA/apex - cuda = False if force_cpu else torch.cuda.is_available() + cuda = False if device == 'cpu' else torch.cuda.is_available() device = torch.device('cuda:0' if cuda else 'cpu') if not cuda: