Fix/determinism #231

Merged · 3 commits · merged Mar 12, 2024

Changes from all commits
2 changes: 0 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py

@@ -396,8 +396,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
3 changes: 1 addition & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py

@@ -1,6 +1,7 @@
 """Dreamer-V2 implementation from [https://arxiv.org/abs/2010.02193](https://arxiv.org/abs/2010.02193).
 Adapted from the original implementation from https://github.com/danijar/dreamerv2
 """
+
 from __future__ import annotations

 import copy
@@ -419,8 +420,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py

@@ -384,8 +384,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/droq/droq.py

@@ -140,8 +140,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py

@@ -394,8 +394,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py

@@ -34,8 +34,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path)
     resume_from_checkpoint = cfg.checkpoint.resume_from is not None
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py

@@ -514,8 +514,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py

@@ -34,8 +34,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path)
     resume_from_checkpoint = cfg.checkpoint.resume_from is not None
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py

@@ -550,8 +550,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     if cfg.checkpoint.resume_from:
         state = fabric.load(cfg.checkpoint.resume_from)
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py

@@ -29,8 +29,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     ckpt_path = pathlib.Path(cfg.checkpoint.exploration_ckpt_path)
     resume_from_checkpoint = cfg.checkpoint.resume_from is not None
2 changes: 0 additions & 2 deletions sheeprl/algos/ppo/ppo.py

@@ -119,8 +119,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     rank = fabric.global_rank
     world_size = fabric.world_size
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
4 changes: 0 additions & 4 deletions sheeprl/algos/ppo/ppo_decoupled.py

@@ -35,8 +35,6 @@ def player(
     # Initialize the fabric object
     log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False)
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
@@ -377,8 +375,6 @@ def trainer(
     )
     fabric.launch()
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py

@@ -127,8 +127,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     rank = fabric.global_rank
     world_size = fabric.world_size
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/sac/sac.py

@@ -91,8 +91,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
4 changes: 0 additions & 4 deletions sheeprl/algos/sac/sac_decoupled.py

@@ -35,8 +35,6 @@ def player(
     log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False)
     rank = fabric.global_rank
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
@@ -352,8 +350,6 @@ def trainer(
     )
     fabric.launch()
     device = fabric.device
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
2 changes: 0 additions & 2 deletions sheeprl/algos/sac_ae/sac_ae.py

@@ -134,8 +134,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     device = fabric.device
     rank = fabric.global_rank
     world_size = fabric.world_size
-    fabric.seed_everything(cfg.seed)
-    torch.backends.cudnn.deterministic = cfg.torch_deterministic

     # Resume from checkpoint
     if cfg.checkpoint.resume_from:
19 changes: 18 additions & 1 deletion sheeprl/cli.py

@@ -166,7 +166,24 @@ def run_algorithm(cfg: Dict[str, Any]):
     for k in keys_to_remove:
         cfg.model_manager.models.pop(k, None)
     cfg.model_manager.disabled == cfg.model_manager.disabled or len(cfg.model_manager.models) == 0
-    fabric.launch(command, cfg, **kwargs)
+
+    # This function makes the algorithm reproducible.
+    # It can be overkill, since Fabric already captures everything we set here
+    # when multiprocessing is used with a `spawn` method (the default with the DDP strategy).
+    # https://github.com/Lightning-AI/pytorch-lightning/blob/f23b3b1e7fdab1d325f79f69a28706d33144f27e/src/lightning/fabric/strategies/launchers/multiprocessing.py#L112
+    def reproducible(func):
+        def wrapper(fabric: Fabric, cfg: Dict[str, Any], *args, **kwargs):
+            if cfg.cublas_workspace_config is not None:
+                os.environ["CUBLAS_WORKSPACE_CONFIG"] = cfg.cublas_workspace_config
+            fabric.seed_everything(cfg.seed)
+            torch.backends.cudnn.benchmark = cfg.torch_backends_cudnn_benchmark
+            torch.backends.cudnn.deterministic = cfg.torch_backends_cudnn_deterministic
+            torch.use_deterministic_algorithms(cfg.torch_use_deterministic_algorithms)
+            return func(fabric, cfg, *args, **kwargs)
+
+        return wrapper
+
+    fabric.launch(reproducible(command), cfg, **kwargs)


 def eval_algorithm(cfg: DictConfig):
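The decorator added here centralizes the seeding and cuDNN/cuBLAS settings that the commits above delete from every algorithm entrypoint: `fabric.launch` now runs the wrapped command, so the flags are applied once, in every spawned process, right before training starts. Below is a minimal standalone sketch of the same pattern; the `train` entrypoint and the plain-dict `cfg` are illustrative stand-ins, not code from this PR:

```python
import os
from typing import Any, Dict

import torch
from lightning.fabric import Fabric


def reproducible(func):
    """Apply seeding and determinism settings just before the wrapped entrypoint runs."""

    def wrapper(fabric: Fabric, cfg: Dict[str, Any], *args, **kwargs):
        if cfg["cublas_workspace_config"] is not None:
            # Must be set before the first cuBLAS call to take effect.
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = cfg["cublas_workspace_config"]
        fabric.seed_everything(cfg["seed"])
        torch.backends.cudnn.benchmark = cfg["torch_backends_cudnn_benchmark"]
        torch.backends.cudnn.deterministic = cfg["torch_backends_cudnn_deterministic"]
        torch.use_deterministic_algorithms(cfg["torch_use_deterministic_algorithms"])
        return func(fabric, cfg, *args, **kwargs)

    return wrapper


def train(fabric: Fabric, cfg: Dict[str, Any]):
    # Hypothetical entrypoint: prints the same values on every run with the same seed.
    print(torch.rand(2))


if __name__ == "__main__":
    cfg = {
        "seed": 42,
        "cublas_workspace_config": None,
        "torch_backends_cudnn_benchmark": True,
        "torch_backends_cudnn_deterministic": False,
        "torch_use_deterministic_algorithms": False,
    }
    # Fabric passes itself as the first argument to the launched function,
    # matching the `wrapper(fabric, cfg, ...)` signature above.
    Fabric(accelerator="cpu", devices=1).launch(reproducible(train), cfg)
```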
30 changes: 29 additions & 1 deletion sheeprl/configs/config.yaml

@@ -22,7 +22,35 @@ dry_run: False

 # Reproducibility
 seed: 42
-torch_deterministic: False
+
+# For more information about reproducibility in PyTorch, see https://pytorch.org/docs/stable/notes/randomness.html
+
+# torch.use_deterministic_algorithms() lets you configure PyTorch to use deterministic algorithms
+# instead of nondeterministic ones where available,
+# and to throw an error if an operation is known to be nondeterministic (and has no deterministic alternative).
+torch_use_deterministic_algorithms: False
+
+# Disabling the benchmarking feature with torch.backends.cudnn.benchmark = False
+# causes cuDNN to select an algorithm deterministically, possibly at the cost of reduced performance.
+# However, if you do not need reproducibility across multiple executions of your application,
+# performance might improve if the benchmarking feature is enabled with torch.backends.cudnn.benchmark = True.
+torch_backends_cudnn_benchmark: True
+
+# While disabling CUDA convolution benchmarking (discussed above) ensures that CUDA selects the same algorithm
+# each time the application is run, that algorithm itself may be nondeterministic
+# unless either torch.use_deterministic_algorithms(True) or torch.backends.cudnn.deterministic = True is set.
+# The latter setting controls only this behavior,
+# unlike torch.use_deterministic_algorithms(), which makes other PyTorch operations behave deterministically, too.
+torch_backends_cudnn_deterministic: False
+
+# From: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility
+# By design, all cuBLAS API routines from a given toolkit version generate the same bit-wise results at every run
+# when executed on GPUs with the same architecture and the same number of SMs.
+# However, bit-wise reproducibility is not guaranteed across toolkit versions,
+# because the implementation might differ due to implementation changes.
+# This guarantee holds only when a single CUDA stream is active.
+# If multiple concurrent streams are active, the library may optimize total performance by picking different internal implementations.
+cublas_workspace_config: null # Possible values are: ":4096:8" or ":16:8"

 # Output folders
 exp_name: ${algo.name}_${env.id}
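For reference, a fully deterministic run corresponds to the runtime calls below. This is a minimal sketch of what the new config keys translate to; the strict values shown here are deliberately not the defaults above:

```python
import os

import torch

# cublas_workspace_config: ":4096:8" -- must be exported before the first cuBLAS call;
# ":16:8" is the lower-memory alternative, possibly at some performance cost.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# torch_backends_cudnn_benchmark: False -- cuDNN selects its algorithm deterministically.
torch.backends.cudnn.benchmark = False

# torch_backends_cudnn_deterministic: True -- the selected cuDNN algorithm is itself deterministic.
torch.backends.cudnn.deterministic = True

# torch_use_deterministic_algorithms: True -- use deterministic implementations everywhere,
# raising an error for any operation that has none.
torch.use_deterministic_algorithms(True)
```

The defaults shipped in this config (benchmark enabled, everything else off or null) keep the previous speed-oriented behavior, so determinism remains opt-in.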