diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 58bf5ff4..11a09627 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -13,6 +13,7 @@ def init_seeds(seed=0): if CUDA_AVAILABLE: torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) + # torch.cuda.set_device(0) # OPTIONAL: Set your GPU if multiple available def select_device(force_cpu=False):