diff --git a/utils/torch_utils.py b/utils/torch_utils.py index b58be1cb..90469ab0 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -5,7 +5,6 @@ def init_seeds(seed=0): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.benchmark = True # set False for reproducible resuls # torch.backends.cudnn.deterministic = True # https://pytorch.org/docs/stable/notes/randomness.html @@ -16,6 +15,7 @@ def select_device(force_cpu=False): if not cuda: print('Using CPU') if cuda: + torch.backends.cudnn.benchmark = True # set False for reproducible results c = 1024 ** 2 # bytes to MB ng = torch.cuda.device_count() x = [torch.cuda.get_device_properties(i) for i in range(ng)]