diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py
index 44fe8af32da..695cd784c7f 100644
--- a/src/accelerate/commands/config/cluster.py
+++ b/src/accelerate/commands/config/cluster.py
@@ -446,6 +446,12 @@ def get_cluster_input():
             default=True,
             error_message="Please enter yes or no.",
         )
+        fsdp_config["fsdp_activation_checkpointing"] = _ask_field(
+            "Do you want to enable FSDP activation checkpointing? [yes/NO]: ",
+            _convert_yes_no_to_bool,
+            default=False,
+            error_message="Please enter yes or no.",
+        )
 
     megatron_lm_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index 1932fe0592a..4189aacb63a 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -250,6 +250,7 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]:
         current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower()
         current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
         current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
+        current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
 
     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"
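For context: this diff only adds the config prompt and exports `FSDP_ACTIVATION_CHECKPOINTING` from the launcher; it does not show where the variable is consumed. Below is a minimal sketch of how a downstream consumer might act on it, using PyTorch's public `apply_activation_checkpointing` utility. The helper name `maybe_apply_activation_checkpointing` and the `check_fn` wrapping policy are hypothetical illustrations, not Accelerate's actual implementation.

```python
# Hypothetical sketch (not part of this diff) of consuming the env var set by
# prepare_multi_gpu_env. The launcher writes str(bool).lower(), so the value
# is the literal string "true" or "false".
import functools
import os

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def maybe_apply_activation_checkpointing(model: nn.Module) -> nn.Module:
    """Wrap eligible submodules with activation checkpointing when the
    launcher exported FSDP_ACTIVATION_CHECKPOINTING=true."""
    if os.environ.get("FSDP_ACTIVATION_CHECKPOINTING", "false") == "true":
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=functools.partial(
                checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
            ),
            # Hypothetical policy: checkpoint each transformer layer. Real code
            # would reuse the same policy that drives FSDP auto-wrapping.
            check_fn=lambda m: isinstance(m, nn.TransformerEncoderLayer),
        )
    return model
```

The sketch pins `CheckpointImpl.NO_REENTRANT` because PyTorch recommends the non-reentrant variant when combining activation checkpointing with FSDP; the wrapping predicate should match the FSDP auto-wrap granularity so checkpointed regions align with sharded units.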