diff --git a/utils/torch_utils.py b/utils/torch_utils.py index b984b265..a93b79d1 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -1,6 +1,7 @@ import os import torch +import torch.backends.cudnn as cudnn def init_seeds(seed=0): @@ -8,8 +9,8 @@ def init_seeds(seed=0): # Remove randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html if seed == 0: - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + cudnn.deterministic = True + cudnn.benchmark = False def select_device(device='', apex=False, batch_size=None):