From 95173b2f113e2d715d6e876172fec44584349848 Mon Sep 17 00:00:00 2001 From: Howard Liberty Date: Tue, 14 May 2024 14:48:22 -0700 Subject: [PATCH 1/2] Enable config for fsdp activation checkpointing --- src/accelerate/commands/config/cluster.py | 6 ++++++ src/accelerate/utils/launch.py | 1 + 2 files changed, 7 insertions(+) diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py index 44fe8af32da..26e33344d01 100644 --- a/src/accelerate/commands/config/cluster.py +++ b/src/accelerate/commands/config/cluster.py @@ -446,6 +446,12 @@ def get_cluster_input(): default=True, error_message="Please enter yes or no.", ) + fsdp_config["fsdp_activation_checkpointing"] = _ask_field( + "Do you want to enable FSDP activation checkpointing? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) megatron_lm_config = {} if distributed_type in [DistributedType.MULTI_GPU]: diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py index 1932fe0592a..4189aacb63a 100644 --- a/src/accelerate/utils/launch.py +++ b/src/accelerate/utils/launch.py @@ -250,6 +250,7 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]: current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower() current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower() current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower() + current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower() if args.use_megatron_lm: prefix = "MEGATRON_LM_" From db4bac6a39588490d4279f216a9994b259b355ec Mon Sep 17 00:00:00 2001 From: Howard Liberty Date: Tue, 14 May 2024 15:01:39 -0700 Subject: [PATCH 2/2] Fix ruff errors --- src/accelerate/commands/config/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py index 26e33344d01..695cd784c7f 100644 --- a/src/accelerate/commands/config/cluster.py +++ b/src/accelerate/commands/config/cluster.py @@ -447,7 +447,7 @@ def get_cluster_input(): error_message="Please enter yes or no.", ) fsdp_config["fsdp_activation_checkpointing"] = _ask_field( - "Do you want to enable FSDP activation checkpointing? [yes/NO]: ", + "Do you want to enable FSDP activation checkpointing? [yes/NO]: ", _convert_yes_no_to_bool, default=False, error_message="Please enter yes or no.",