Fix/determinism (#231)
* Add all torch reproducible settings

* Add cublas_workspace_config
belerico authored Mar 12, 2024
1 parent 07f60d7 commit afee2cd
Showing 18 changed files with 48 additions and 38 deletions.
2 changes: 0 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -396,8 +396,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
3 changes: 1 addition & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -1,6 +1,7 @@
 """Dreamer-V2 implementation from [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
 Adapted from the original implementation from https://github.com/danijar/dreamerv2
 """
+
 from __future__ import annotations
 
 import copy
@@ -419,8 +420,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -384,8 +384,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/droq/droq.py
@@ -140,8 +140,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -394,8 +394,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -34,8 +34,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path)
     resume_from_checkpoint = cfg.checkpoint.resume_from is not None
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -514,8 +514,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -34,8 +34,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path)
     resume_from_checkpoint = cfg.checkpoint.resume_from is not None
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -550,8 +550,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -29,8 +29,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path)
     resume_from_checkpoint = cfg.checkpoint.resume_from is not None
2 changes: 0 additions & 2 deletions sheeprl/algos/ppo/ppo.py
@@ -119,8 +119,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     rank = fabric.global_rank
     world_size = fabric.world_size
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
4 changes: 0 additions & 4 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -35,8 +35,6 @@ def player(
     # Initialize the fabric object
     log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False)
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
@@ -377,8 +375,6 @@ def trainer(
     )
     fabric.launch()
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -127,8 +127,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     rank = fabric.global_rank
     world_size = fabric.world_size
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/sac/sac.py
@@ -91,8 +91,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
4 changes: 0 additions & 4 deletions sheeprl/algos/sac/sac_decoupled.py
@@ -35,8 +35,6 @@ def player(
     log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False)
     rank = fabric.global_rank
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
@@ -352,8 +350,6 @@ def trainer(
     )
     fabric.launch()
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/sac_ae/sac_ae.py
@@ -134,8 +134,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic
 
     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
19 changes: 18 additions & 1 deletion sheeprl/cli.py
@@ -166,7 +166,24 @@ def run_algorithm(cfg: Dict[str, Any]):
         for k in keys_to_remove:
             cfg.model_manager.models.pop(k, None)
         cfg.model_manager.disabled == cfg.model_manager.disabled or len(cfg.model_manager.models) == 0
-    fabric.launch(command, cfg, **kwargs)
+
+    # This function makes the algorithm reproducible.
+    # It can be overkill, since Fabric already captures everything we're setting here
+    # when multiprocessing is used with a `spawn` method (the default with the DDP strategy).
+    # https://github.com/Lightning-AI/pytorch-lightning/blob/f23b3b1e7fdab1d325f79f69a28706d33144f27e/src/lightning/fabric/strategies/launchers/multiprocessing.py#L112
+    def reproducible(func):
+        def wrapper(fabric: Fabric, cfg: Dict[str, Any], *args, **kwargs):
+            if cfg.cublas_workspace_config is not None:
+                os.environ["CUBLAS_WORKSPACE_CONFIG"] = cfg.cublas_workspace_config
+            fabric.seed_everything(cfg.seed)
+            torch.backends.cudnn.benchmark = cfg.torch_backends_cudnn_benchmark
+            torch.backends.cudnn.deterministic = cfg.torch_backends_cudnn_deterministic
+            torch.use_deterministic_algorithms(cfg.torch_use_deterministic_algorithms)
+            return func(fabric, cfg, *args, **kwargs)
+
+        return wrapper
+
+    fabric.launch(reproducible(command), cfg, **kwargs)
 
 
 def eval_algorithm(cfg: DictConfig):
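A minimal sketch of how this launch-time wrapping behaves (simplified and hypothetical, assuming the `reproducible` decorator above is in scope; the config values here are illustrative): `fabric.launch(fn, *args)` invokes `fn(fabric, *args)` in every spawned process, so the reproducibility settings are applied once per process, before the algorithm entrypoint runs.

    from lightning.fabric import Fabric
    from omegaconf import OmegaConf

    def command(fabric: Fabric, cfg) -> None:
        # By the time we get here, `wrapper` has already seeded the process
        # and applied the torch/cuBLAS determinism flags.
        print(f"rank={fabric.global_rank}: deterministic setup done")

    cfg = OmegaConf.create(
        {
            "seed": 42,
            "cublas_workspace_config": None,
            "torch_backends_cudnn_benchmark": False,
            "torch_backends_cudnn_deterministic": True,
            "torch_use_deterministic_algorithms": True,
        }
    )
    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
    fabric.launch(reproducible(command), cfg)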
30 changes: 29 additions & 1 deletion sheeprl/configs/config.yaml
@@ -22,7 +22,35 @@ dry_run: False
 
 # Reproducibility
 seed: 42
-torch_deterministic: False
+
+# For more information about reproducibility in PyTorch, see https://pytorch.org/docs/stable/notes/randomness.html
+
+# torch.use_deterministic_algorithms() lets you configure PyTorch to use deterministic algorithms
+# instead of nondeterministic ones where available,
+# and to throw an error if an operation is known to be nondeterministic (and has no deterministic alternative).
+torch_use_deterministic_algorithms: False
+
+# Disabling the benchmarking feature with torch.backends.cudnn.benchmark = False
+# causes cuDNN to select an algorithm deterministically, possibly at the cost of reduced performance.
+# However, if you do not need reproducibility across multiple executions of your application,
+# performance might improve with the benchmarking feature enabled via torch.backends.cudnn.benchmark = True.
+torch_backends_cudnn_benchmark: True
+
+# While disabling CUDA convolution benchmarking (discussed above) ensures that CUDA selects the same algorithm
+# each time an application is run, that algorithm itself may be nondeterministic unless either
+# torch.use_deterministic_algorithms(True) or torch.backends.cudnn.deterministic = True is set.
+# The latter setting controls only this behavior, unlike torch.use_deterministic_algorithms(),
+# which makes other PyTorch operations behave deterministically, too.
+torch_backends_cudnn_deterministic: False
+
+# From: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+# By design, all cuBLAS API routines from a given toolkit version generate the same bit-wise results at every run
+# when executed on GPUs with the same architecture and the same number of SMs.
+# However, bit-wise reproducibility is not guaranteed across toolkit versions,
+# because the implementation may differ between versions.
+# The guarantee also holds only when a single CUDA stream is active;
+# if multiple concurrent streams are active, the library may optimize total performance by picking different internal implementations.
+cublas_workspace_config: null  # Possible values are: ":4096:8" or ":16:8"
 
 # Output folders
 exp_name: ${algo.name}_${env.id}
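A usage note on `cublas_workspace_config` (not stated in the diff, but from the PyTorch reproducibility docs linked above): on CUDA 10.2 or later, `torch.use_deterministic_algorithms(True)` makes deterministic cuBLAS operations raise a RuntimeError unless CUBLAS_WORKSPACE_CONFIG is set to ":4096:8" (slightly larger GPU memory footprint) or ":16:8" (may reduce performance). A minimal standalone sketch of what the combined settings provide:

    import os

    # Must be set before the first cuBLAS call, so set it before any CUDA work.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    import torch

    torch.backends.cudnn.benchmark = False      # torch_backends_cudnn_benchmark: False
    torch.backends.cudnn.deterministic = True   # torch_backends_cudnn_deterministic: True
    torch.use_deterministic_algorithms(True)    # torch_use_deterministic_algorithms: True

    def run_once(seed: int) -> torch.Tensor:
        torch.manual_seed(seed)  # seeds CPU and all CUDA devices
        device = "cuda" if torch.cuda.is_available() else "cpu"
        x = torch.randn(256, 256, device=device)
        return x @ x

    # With the settings above, repeated runs produce bit-wise identical results
    # on the same hardware and software stack.
    assert torch.equal(run_once(42), run_once(42))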
