Skip to content

Commit

Permalink
feat: add torch reproducibility code
Browse files Browse the repository at this point in the history
This commit adds extra code that can be uncommented to increase algorithm
reproducibility over different machines.
  • Loading branch information
rickstaa committed Aug 8, 2023
1 parent 14f9860 commit df5060c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,8 @@ def lac(
os.environ["PYTHONHASHSEED"] = str(
seed
) # Ensure python hashing is deterministic.
# torch.use_deterministic_algorithms(True) # Disable for reproducibility.
# torch.backends.cudnn.benchmark = False # Disable for reproducibility.

policy = LAC(
env,
Expand Down
2 changes: 2 additions & 0 deletions stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,8 @@ def sac(
os.environ["PYTHONHASHSEED"] = str(
seed
) # Ensure python hashing is deterministic.
# torch.use_deterministic_algorithms(True) # Disable for reproducibility.
# torch.backends.cudnn.benchmark = False # Disable for reproducibility.

policy = SAC(
env,
Expand Down

0 comments on commit df5060c

Please sign in to comment.