diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 40eed4ae..02543118 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -6,7 +6,11 @@ def init_seeds(seed=0): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - # torch.backends.cudnn.deterministic = True # https://pytorch.org/docs/stable/notes/randomness.html + + # Remove randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html + if seed == 0: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False def select_device(force_cpu=False, apex=False):