From df5060cbc807c05fae34c10e62fe6dd3a642e928 Mon Sep 17 00:00:00 2001 From: rickstaa Date: Tue, 8 Aug 2023 16:17:06 +0200 Subject: [PATCH] feat: add torch reproducibility code This commit adds extra code that can be uncommented to increase algorithm reproducibility over different machines. --- stable_learning_control/algos/pytorch/lac/lac.py | 2 ++ stable_learning_control/algos/pytorch/sac/sac.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/stable_learning_control/algos/pytorch/lac/lac.py b/stable_learning_control/algos/pytorch/lac/lac.py index e6eaf452e..4dc8ea117 100644 --- a/stable_learning_control/algos/pytorch/lac/lac.py +++ b/stable_learning_control/algos/pytorch/lac/lac.py @@ -1049,6 +1049,8 @@ def lac( os.environ["PYTHONHASHSEED"] = str( seed ) # Ensure python hashing is deterministic. + # torch.use_deterministic_algorithms(True) # Disable for reproducibility. + # torch.backends.cudnn.benchmark = False # Disable for reproducibility. policy = LAC( env, diff --git a/stable_learning_control/algos/pytorch/sac/sac.py b/stable_learning_control/algos/pytorch/sac/sac.py index c963a5578..f52fa102d 100644 --- a/stable_learning_control/algos/pytorch/sac/sac.py +++ b/stable_learning_control/algos/pytorch/sac/sac.py @@ -994,6 +994,8 @@ def sac( os.environ["PYTHONHASHSEED"] = str( seed ) # Ensure python hashing is deterministic. + # torch.use_deterministic_algorithms(True) # Disable for reproducibility. + # torch.backends.cudnn.benchmark = False # Disable for reproducibility. policy = SAC( env,