diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 7d82a236..adbc8705 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -28,7 +28,7 @@ def select_device(force_cpu=False): x = [torch.cuda.get_device_properties(i) for i in range(ng)] cuda_str = 'Using CUDA ' + apex_str for i in range(0, ng): - if ng == 1: + if i == 1: # torch.cuda.set_device(0) # OPTIONAL: Set GPU ID cuda_str = ' ' * len(cuda_str) print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" %