-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add APIs to offload states of model, optimizer, and engine #6011
Conversation
Hi @tohtana , Thank you for your work. I've been trying the new APIs to test model offloading in a multi-model deployment (e.g., deepspeed-chat) as part of #5620 . Although the API works in offloading a model and reducing GPU memory initially, after bringing the model back and completing the first training iteration (i.e., optimiser states have been updated), I get a
|
Thank you for reporting, @kfertakis! I have an example script showing the usage of the APIs. Can you try this? |
So I tested the issue again with various models and it seems the problem is model-size related as it does not seem to occur for smaller models (i.e., <= 1B params, e.g., gpt2, gpt2-medium) and it does for bigger ones(i.e., OPT-1.3B, mistral-7B). Is there anything I could do to investigate it further and debug it? By the way, I should mention that I'm testing this in a single node, single GPU configuration (i.e., single worker) thus ZeRO3 partitioning should not have to partition data across other workers. I will also test the benchmark you referenced with an artificially larger model size setting. Thanks again. |
Hi @kfertakis, I tried this example with a 4B model but it worked. Can you try this on your environment? |
@tohtana, I wonder if it is useful to expose Similar to how @kfertakis, would love to get your thoughts as well on whether any of the above would be useful? Thanks! |
Hey, thanks for the comments. @tohtana, I've tried the example you provided and it does seem to work so I'm sharing a fork of the DeepSpeed-Examples repo to showcase the problem. I've modified the DeepSpeed-Chat code to use
this should lead to the @tjruwase thanks for the reference. Current problem aside, I can see how the helper functions can be useful in the future for ensuring consistency. thanks. |
Hi @kfertakis, thank you for sharing the repro. It seems that the actual issue is related to ZeRO3's prefetching. I opened #6557 as a workaround to address this issue. Can you try the branch |
Hi @tohtana, thank you for your work. I tried your branch and the issue seems to be fixed. I will continue testing and raise any new issues but for now, the |
I also wanted to ask if the offloading functionality could be extended to support |
@tjruwase Let me address this by another PR after this one is merged. |
Thank you @kfertakis for validating the fix.
Let me consider how to do this. Please feel free to open a new issue to track it as I am going to merge this PR first. |
Parameters prefetched by ZeRO3 are sometimes not used. This occurs when the actual sub-module execution differs from previous tracing. As a result, the state of the allgather handle for such a parameter remains `INFLIGHT`, causing functions like `empty_partition_cache` to detect it and throw an error. This PR resolves the issue by ensuring that communication finishes and the parameters are freed. As this issue was mentioned in #6011, this includes the change of the branch. We need to merge #6011 first. --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]>
This PR adds an API `deepspeed.runtime.zero.offload_states get_state_devices`, which gets devices of offload states as suggested in this [comment](#6011 (comment)). We could lift this up to `deepspeed.utils` but would need to resolve a circular import: User code -> `deepspeed.utils` -> `deepspeed.utils.offload_states` -> `deepspeed.runtime.zero` -> `deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils` This will require a significant refactoring as long as we have `OffloadStateTypeEnum` in `deepspeed.runtime.zero`. --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]>
* Safe usage of popen (#6490) Avoid shell=True security issues with Popen * Handle an edge case where `CUDA_HOME` is not defined on ROCm systems (#6488) * Handles an edge case when building `gds` where `CUDA_HOME` is not defined on ROCm systems * Update version.txt after 0.15.1 release (#6493) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.15.1 Author - @loadams Co-authored-by: loadams <[email protected]> * HPU: add required ENV vars to acccelerator init (#6495) Co-authored-by: Olatunji Ruwase <[email protected]> * Op_builder->is_compatible quite warning (#6093) Set the default value of op_builder/xxx.py/is_compatible()/verbose to False for quite warning. Add verbose judgement before op_builder/xxx.py/is_compatible()/self.warning(...). Otherwise the verbose arg will not work. --------- Co-authored-by: Logan Adams <[email protected]> * fix pipeline eval_batch micro_batches argument for schedule (#6484) Co-authored-by: Logan Adams <[email protected]> * Fix the broken url link (#6500) Simple changes to fix the Intel cpu example link and add more xpu examples. Signed-off-by: roger feng <[email protected]> * fix environment variable export bug for MultiNodeRunner (#5878) In some multi-node environment like SLURM,there are some environment vars that contain special chars and can trigger errors when being exported. For example, there is a var `SLURM_JOB_CPUS_PER_NODE=64(x2)` when requesting two nodes with 64 cpus using SLURM. Using `runner.add_export` to export this var will add a command `export SLURM_JOB_CPUS_PER_NODE=64(x2)` when launching subprocesses, while this will cause a bash error since `(` is a key word of bash, like: ``` [2024-08-07 16:56:24,651] [INFO] [runner.py:568:main] cmd = pdsh -S -f 1024 -w server22,server27 export PYTHONPATH=/public/home/grzhang/code/CLIP-2; export SLURM_JOB_CPUS_PER_NODE=64(x2); ... server22: bash: -c: 行 0: 未预期的符号“(”附近有语法错误 ``` This PR simply wrap the environment vars with a pair of `"` to make sure they are treated as string. Co-authored-by: Logan Adams <[email protected]> * Revert "BF16 optimizer: Clear lp grads after updating hp grads in hook" (#6508) Reverts microsoft/DeepSpeed#5328 After offline discussion with @YangQun1 , we agreed that there is no memory effect as clear_lp_grads flag triggers zero_() ops which just zeros buffers and does not free any memory. the outcome is compute overhead. * wrap include cuda_bf16.h with ifdef BF16_AVAILABLE (#6520) * Avoid security issues of subprocess shell (#6498) Avoid security issues of `shell=True` in subprocess --------- Co-authored-by: Logan Adams <[email protected]> * Add conditional on torch version for scaled_dot_product_attention (#6517) Changes from #4724 broke support for torch<2.0 in the flops profiler as the scaled_dot_product_attention [wasn't added](https://pytorch.org/docs/2.0/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention) until a beta version in torch 2.0 Resolved: #5534 Todo: - [ ] Test this - [ ] Issue resolution with users. * Added Intel Gaudi to Accelerator Setup Guide (#6543) Added Intel Gaudi to the list of accelerators in the setup guide. Co-authored-by: sakell <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Skip failing newly added tests in accelerate (#6574) Adding the new tests in https://github.com/huggingface/accelerate/pull/3097 caused the nv-accelerate-v100 tests to fail. Due to other CI issues we didn't notice this at first. This just skips the problematic test for now. cc: @stas00 / @muellerzr * Use msgpack for p2p comm (#6547) Use msgpack for P2P communication in pipeline engine. Co-authored-by: Logan Adams <[email protected]> * DeepNVMe perf tuning (#6560) Add performance tuning utilities: `ds_nvme_tune` and `ds_io`. Update tutorial with tuning section. --------- Co-authored-by: Ubuntu <[email protected]> Co-authored-by: Joe Mayer <[email protected]> * [Accelerator] Cambricon MLU support (#6472) ### Description This PR includes Cambricon MLU accelerator support. With this PR, DeepSpeed supports MLU as backend for training and inference tasks. --------- Co-authored-by: Logan Adams <[email protected]> * Fix gradient accumulation for Z2+offload (#6550) The ZeRO 1/2 optimizer performs incorrect gradient accumulation in the path for ZeRO2 + Offloading. This issue is caused by two main reasons: 1) The micro_step_id in the ZeRO 1/2 optimizer is: - Initialized to 0 in the constructor. - Reset to -1 during the backward pass. For example, given a gradient accumulation step of 4, the micro_step_id changes as follows: - For the first global step: 1, 2, 3, 4. - Subsequently: 0, 1, 2, 3. 2) Gradients are copied to the buffer on the first micro step and accumulated in the buffer during the following micro steps. However, the current code incorrectly copies gradients at steps that are not at the accumulation boundary. This PR aligns the micro_step_id initialization in both the constructor and the backward pass, and corrects the condition for copying and accumulating gradients. Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * fix errors when setting zero3 leaf modules with torch.compile (#6564) When setting zero3 leaf modules to a higher level module and running with torch.compile, there are a few errors from ZeROOrderedDict. First it doesn't support Deep copy for not having a constructor with no parameters. Second, it doesn't check the existence of ds_status attr on param before accessing the attr. change contributed by Haifeng Chen Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * [XPU] Support DeepNVMe new code structure (#6532) In DeepNVMe GDS update, many functions are changed into a more abstract way. Also added some files. These change break zero-infinity on XPU. To bring this feature back, we have this PR: 1. modify the aio opbuilder for new files. 2. Add custom cpu_op_desc_t for xpu users. (XPU don't handle buffer aligned here) --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Add APIs to offload states of model, optimizer, and engine (#6011) This PR adds the following APIs to offload model, optimizer, and engine states. ```pytyon def offload_states(self, include: Container[OffloadStateTypeEnum] = None, device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, pin_memory: bool = True, non_blocking: bool = False) -> None: """Move the ZeRO optimizer buffers to the specified device. Arguments: include: Optional. The set of states to offload. If not provided, all states are offloaded. device: Optional. The device to move the ZeRO optimizer buffers to. pin_memory: Optional. Whether to pin the memory of the offloaded states. non_blocking: Optional. Whether to offload the states asynchronously. ... def offload_states_back(self, non_blocking: bool = False) -> None: ``` Here is the typical usage. ```python # Offload after forward, backward, and step model.offload_states() # Do something requiring a lot of device memory ... # Load states back to device memory model.offload_states_back() ``` You can selectively offload states to balance the offloading overhead and memory saving. ```python model.offload_states(include=set([OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.opt_states], device=OffloadDeviceEnum.cpu) ``` Performance (4.3B parameters / 4x A100) - Environment (4x A100, [benchmark script](https://gist.github.com/tohtana/05d5faba5068cf839abfc7b1e38b85e4)) - Average Device to Host transfer time: 2.45 GB/s, aggregated: 9.79 GB/s - Average Host to Device transfer: 11.05 GB/s, aggregated: 44.19 GB/s - Mem (allocated by PyTorch) - Before offload 18.2GB - After offloading 17.7MB - Time ([benchmark script](https://github.com/microsoft/DeepSpeedExamples/tree/tohtana/offload_states/training/offload_states), offloading time/loading time) python output_table.py | |pin_memory=0 non_blocking=0|pin_memory=0 non_blocking=1|pin_memory=1 non_blocking=0|pin_memory=1 non_blocking=1| |--:|---------------------------|---------------------------|---------------------------|---------------------------| | 1|4.34 / 3.42 |4.99 / 2.37 |6.5 / 2.42 |6.0 / 2.39 | | 2|9.9 / 3.28 |5.1 / 2.34 |6.21 / 2.42 |6.25 / 2.45 | | 3|9.92 / 3.19 |6.71 / 2.35 |6.33 / 2.38 |5.93 / 2.42 | | 4|9.55 / 2.82 |7.11 / 2.39 |6.9 / 2.38 |6.5 / 2.43 | | 5|4.4 / 3.35 |6.04 / 2.41 |6.26 / 2.41 |6.32 / 2.47 | | 6|4.4 / 3.57 |6.58 / 2.42 |6.88 / 2.4 |6.35 / 2.43 | | 7|9.51 / 3.12 |6.9 / 2.39 |6.9 / 2.39 |6.46 / 2.4 | | 8|4.77 / 3.64 |6.69 / 2.39 |7.39 / 2.42 |6.56 / 2.46 | | 9|9.5 / 3.07 |7.18 / 2.42 |6.67 / 2.39 |7.38 / 2.46 | TODO: - Enable offloading to a NVMe storage -> NVMe support is non-trivial. I suggest adding the support in another PR - [DONE] Discard buffer (and recreate it) instead of offloading. We don't need to restore the contiguous buffer for reduce. - [DONE] Check pin_memory improves performance or not --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * add bfloat16 to inference support dtypes (#6528) to allow running inference tasks using bfloat16 --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * [COMPILE] workflow for deepspeed + torch.compile (#6570) We use simple model + deepspeed zero 3 + torch.compile and count graph break numbers to demonstrate current status of combing deepspeed + torch.compile. --------- Co-authored-by: Masahiro Tanaka <[email protected]> * Fixes on the accelerate side mean we do not need to skip this test (#6583) HF accelerate implemented fixes here: https://github.com/huggingface/accelerate/pull/3131 This means we can revert the changes from #6574 * Fix torch include in `op_builder/mlu/fused_adam.py` and update no-torch workflow triggers (#6584) Changes from #6472 caused the no-torch workflow that is an example of how we build the DeepSpeed release package to fail (so we caught this before a release, see more in #6402). These changes also copy the style used to include torch in other accelerator op_builder implementations, such as npu [here](https://github.com/microsoft/DeepSpeed/blob/master/op_builder/npu/fused_adam.py#L8) and hpu [here](https://github.com/microsoft/DeepSpeed/blob/828ddfbbda2482412fffc89f5fcd3b0d0eba9a62/op_builder/hpu/fused_adam.py#L15). This also updates the no-torch workflow to run on all changes to the op_builder directory. The test runs quickly and shouldn't add any additional testing burden there. Resolves: #6576 * [ROCm] Fix subprocess error (#6587) Fixes https://github.com/microsoft/DeepSpeed/issues/6585 Use shell=True for subprocess.check_output() in case of ROCm commands. Do not use shlex.split() since command string has wildcard expansion. Signed-off-by: Jagadish Krishnamoorthy <[email protected]> * Cleanup CODEOWNERS file to be valid (#6603) * Add SSF Best practices badge (#6604) Work in progress to ensure we meet SSF best practices: https://www.bestpractices.dev/en/projects/9530 * Move V100 workflows from cuda 11.1/11.7 to 12.1 (#6607) * Fix SD workflow (#6609) SD workflow needed updates when we moved to pydantic 2 support that was never added before. Passing nv-sd workflow [here](https://github.com/microsoft/DeepSpeed/actions/runs/11239699283) * Pin accelerate to fix CI failures/issues (#6610) * Add llama3.2 vision autotp (#6577) Llama3.2-11b and llama3.2-90b including vision model and text model, these two models have different num_kv_heads, so we need to set num_kv_heads dynamically. Co-authored-by: Logan Adams <[email protected]> * Improve DS logging control (#6602) Disable `steps_per_print` by default. * Fix device selection using CUDA_VISIBLE_DEVICES (#6530) This PR addresses #5818. Instead of contiguous numbers based on the device count, this PR uses device indices in `--include`. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Handle when `backend` is also in compile_kwargs (#6502) cc @tohtana Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Rearrange inference OPS and stop using builder.load (#5490) This PR mainly handles all places where InferenceBuilder is used to access any op or a specific implementation for an op. Instead an op is defined, and its proper implementation is picked inside and the usage will be transparent to the user. What was done in the PR: 1) Added missing ops (added a py file with fallback mechanism) 2) Added missing fallback implementations for existing ops 3) removed all usages for builder.load and replaced them with ops instead. 4) added workspace op and inferenceContext which contains all workspace related functions and inferenceContext is the python fallback of inferenceContext in CUDA 5) a small change to softmax_context signature to fit the fallback signature. --------- Co-authored-by: Joe Mayer <[email protected]> Co-authored-by: Lev Kurilenko <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * Unpin accelerate tests, update lightning with node16 removal. (#6611) HF accelerate fixes implemented in https://github.com/huggingface/accelerate/pull/3145 mean that we no longer need to pin the Accelerate version! nv-lightning tests now run on Ubuntu 20.04+, so we support >node 16, so we can remove the explicit permissions for that in the env config. * Enabled Qwen2-MoE Tensor Parallelism (TP) inference (#6551) Modified _replace_module in auto_tp.py : The modification keeps the layers 'shared_expert_gate' and 'gate' in qwen2-moe the original type torch.nn.Linear and not changes them into LinearLayer. In this way, their weights will not be split into multiple HPU/GPU cards. Then the qwen2-moe can run on multiple HPU/GPU cards. Since the weights of 'gate' are not split into multiple HPU/GPU cards, all gather operations are not needed, which may improve performance. --------- Co-authored-by: Logan Adams <[email protected]> * Update version.txt after 0.15.2 release (#6615) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.15.2 Author - @jomayeri Co-authored-by: jomayeri <[email protected]> * Clean up prefetched parameters (#6557) Parameters prefetched by ZeRO3 are sometimes not used. This occurs when the actual sub-module execution differs from previous tracing. As a result, the state of the allgather handle for such a parameter remains `INFLIGHT`, causing functions like `empty_partition_cache` to detect it and throw an error. This PR resolves the issue by ensuring that communication finishes and the parameters are freed. As this issue was mentioned in #6011, this includes the change of the branch. We need to merge #6011 first. --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * AIO CPU Locked Tensor (#6592) Restoring the functionality of the cpu locked tensor in the AIO library. Make async_io operator available for CPU accelerator, i.e., CPU only environment. --------- Co-authored-by: Olatunji Ruwase <[email protected]> * reduce setting global variables to reduce torch compile graph breaks (#6541) setting global variables during training will create a graph breaks when using torch.compile (reading global variables doesn't). this commit attempts to reduce the setting of global variables in the checkpointing flows. there are 2 main uses setting global variables: 1. Share data between functions 2. Establish that this is the first call to the code For most of the cases the data in the global variables is data that can be computed on demand or set once in an initial state in a configure function. For "check that this is the first run" use case the code was moved to the configure function. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Add API to get devices of offload states (#6586) This PR adds an API `deepspeed.runtime.zero.offload_states get_state_devices`, which gets devices of offload states as suggested in this [comment](https://github.com/microsoft/DeepSpeed/pull/6011#issuecomment-2358068777). We could lift this up to `deepspeed.utils` but would need to resolve a circular import: User code -> `deepspeed.utils` -> `deepspeed.utils.offload_states` -> `deepspeed.runtime.zero` -> `deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils` This will require a significant refactoring as long as we have `OffloadStateTypeEnum` in `deepspeed.runtime.zero`. --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * apply fp16 autocast only to floating point values * Ignore reuse_dist_env (#6623) Tests with `reuse_dist_env = True` often causes memory leaks. This PR ignores `reuse_dist_env` and forcibly sets it to `False`. This change might slow down the tests, but I think it is better to manually restart runners and relaunch tests. Memory usages (See #6578): - `reuse_dist_env == True`: https://github.com/microsoft/DeepSpeed/actions/runs/11302940871/job/31439471512 - `reuse_dist_env == False`: https://github.com/microsoft/DeepSpeed/actions/runs/11303250613/job/31440137894 * [compile] Show breakdown of graph break (#6601) This PR extends https://github.com/microsoft/DeepSpeed/pull/6570 by showing a breakdown of graph breaks. So we can see how graph breaks are distributed among different reasons. An example of graph break output can be seen from the following workflow run https://github.com/microsoft/DeepSpeed/actions/runs/11199157962 * Add API for updating ZeRO gradients (#6590) * Accept btl_tcp_if_include option through launcher_args (#6613) This patch fixes issue #4460. When `btl_tcp_if_include` option is provided through `--launcher_args`, we use the provided option instead of the hardcoded `--mca btl_tcp_if_include eth0`. Otherwise we use `--mca btl_tcp_if_include eth0` as the default for compatibility. Fixes #4460 --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * Add first Step in LR Schedulers (#6597) Some (not all) of the LR schedulers in runtime were missing the initialization of the optimizer group lr. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Support safetensors export (#6579) ## Feature This commit implements the following features: - [x] support saving checkpoint as safetensors (more commonly used format) - [x] support sharding checkpoints (which is important for very large models) Most of the codes are borrowed from https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/modeling_utils.py#L2490 ## Usage For `pytorch_model.bin` export ``` python zero_to_fp32.py . output_dir/ ``` For `model.safetensors` export ``` python zero_to_fp32.py . output_dir/ --safe_serialization ``` --------- Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Logan Adams <[email protected]> * add option to disable logger while compiling to avoid graph breaks (#6496) adding an option to disable calls for logger while compiling to avoid graph breaks. Here I used an environment variable to determine whether to activate this option, but it can also be determined using the json config file or any other way you see fit. --------- Co-authored-by: snahir <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Lock cache file of HF model list (#6628) The error in the following log suggests that the cache file for HF model list can be broken: https://github.com/microsoft/DeepSpeed/actions/runs/11343665365/job/31546708118?pr=6614 The actual cause of the above error is unclear, but `_hf_model_list` potentially breaks the cache file when it is concurrently called from multiple processes. This PR locks the cache file to ensure `_hf_model_list` safely reads and writes the file. * Add README Pipeline Status for Huawei Ascend NPU (#6588) Hello! Following the merge of https://github.com/microsoft/DeepSpeed/pull/6445, I have implemented a CI pipeline to validate the Huawei Ascend NPU. --------- Co-authored-by: sjh <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Update torch version in workflows (#6631) Set PyTorch version in CI workflows to v2.5. Context: The [error](https://github.com/microsoft/DeepSpeed/actions/runs/11371525624/job/31633793986?pr=6630) in #6630 might have been caused by the PyTorch version mismatch or something. * Use file store for tests (#6632) This PR changes the `init_method` for tests to `FileStore` for robustness. * Fix Memory Leak In AIO (#6630) Fixing a memory leak in AIO pinned tensor as well as an incorrect function type for gds op. --------- Co-authored-by: Masahiro Tanaka <[email protected]> * [XPU] upgrade xpu max1100 CI workflow to pytorch2.3 (#6646) With intel-extension-for-pytorch=2.3.110 released last month, max1100 CI workflow can be updated too. Software versions aligned with #6570 . Increased CI tests scope for torch/ipex2.3 will be in later PR. This workflow passed in my cloned repo self-hosted runner. * [XPU] host timer check version from Torch 2.5 to Torch 2.6 (#6633) Elapsed time would be supported in Torch 2.6. Co-authored-by: Masahiro Tanaka <[email protected]> * [XPU] [DeepNVMe] use same cpu_op_desc_t with cuda (#6645) We have found that #6592 uses `_pinned_tensor_mgr` to create cpu bounce buffer, which is same with what our xpu accelerator currently doing. So no need to use xpu device specific cpu_op_desc_t. In this PR: 1. remove custom csrc/xpu/aio/deepspeed_cpu_op.cpp 2. modify xpu async_io opbuilder. This issue cannot be easily done with revert #6532 , for we added some source file as last time GDS feature going in DS. So file this new PR :) * Update version.txt after 0.15.3 release (#6652) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.15.3 Author - @jomayeri Co-authored-by: jomayeri <[email protected]> * Fix expert grad scaling problem with ZeRO optimizer (#6546) Fix [#6545] work: - expert gradient average: divide edp_world_size -> divide dp_world_size - unit test: make sure model with different dp/ep has same expert gradient --------- Co-authored-by: wangyiou <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Add attribute check for language_model when replace last linear module (#6650) Fix module has no attribute 'language_model' issue. Co-authored-by: Masahiro Tanaka <[email protected]> * fix init_device_mesh for torch 2.4 (#6614) Start torch 2.4, in [`init_device_mesh()`](https://github.com/pytorch/pytorch/blob/de4c2a3b4e89d96334dc678d1c3f2ae51a6630a0/torch/distributed/device_mesh.py#L915) ,device type with a GPU index, such as "cuda:0", is not allowed.  --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Fix dynamo issue (#6527) Dynamo use faketensor to trace tensor ops. In some case, the mechanism break compiling with deepspeed. An example could be found at https://gist.github.com/oraluben/9b8240c2fe482eb4382453d6c97a5f76, to see issues, install deepspeed==0.14.4 instead of my fork without this PR, llama cannot be compiled. Detailed explanation: 1. `ZeROOrderedDict` dynamo use deepcopy to copy tensors, which will call `object.__reduce__`. When copying `ZeROOrderedDict`, the default implementation do not copy its `_parent_module` and will lead to failure. 2. `param` maybe faketensor and do not have `ds_status` yet, but during tracing it's ok to just skip the `register_external_parameter`, it should be done ways before. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * sequence parallel for uneven heads (#6392) In sequence_parallel (Ulysses), the sequence parallel size is constrained by the requirement to be divisible by the number of heads, which prevents some models/workloads from setting a specific sequence parallel size. This PR implements uneven all-to-all heads splitting. - both support batch first (b,s,...) and seq_len first(s,b..) layout. - Added unit tests with numerical checks. Locally also tested with **7 heads with sp=4** and **20 heads with sp=8**, and it passed. --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Ma, Guokai <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Add fallback for is_compiling (#6663) Importing `torch.compiler.is_compiling` causes an error with an older version of PyTorch. This PR adds a fallback for `is_compiling` to use an equivalent function of older PyTorch versions. This will resolve #6656. Co-authored-by: Logan Adams <[email protected]> * Update profiler registration check (#6668) Resolves #5432. * Add support for H100/sm_90 arch compilation (#6669) Resolves: #6549 * Update Gaudi2 docker image (#6677) * Update gaudi2 docker version to latest release (1.18) (#6648) Updated docker version to 1.18.0-latest Note: for this update the firmware on the Gaudi2 node had to be updated to use firmware version 1.18. Co-authored-by: Logan Adams <[email protected]> * Update base docker image for A6000 GPU tests (#6681) Update to a [container (24.03)](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-03.html) with python 3.10 as transformers dropped support for python 3.8 in their latest release. Note: nv-human-eval.yml was never completed and isn't used, it is just updated for any potential future support. Resolves: #6672 * Remove packages that no longer need to be updated in the latest container (#6682) * Fix training of pipeline based peft's lora model (#5477) Hi, guys I find there is an assert failure when I train huggingface's lora based model in pipeline style. Here is the whole steps that I created my model: 1) Load the pre-trained chatglm-6b model from huggingface, as Model_A 2) Use huggingface's peft's `get_peft_model(...)` and my `LoraConfig(...)` from Model_A to create the lora model, as Model_B 3) Create my own pipeline based model Model_C from Model_B And I run Model_C under 2 3090ti GPUs. And the assertion failure looks like this: ```text Traceback (most recent call last): File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 372, in <module> main() File "/home/ubuntu/proj/chatglm-finetuning/train_pipeline.py", line 351, in main loss = engine.train_batch(data_iter=train_dataloader) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 375, in train_batch self._exec_schedule(sched) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 1375, in _exec_schedule self._exec_instr(**cmd.kwargs) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/runtime/pipe/engine.py", line 276, in _exec_reduce_tied_grads dist.all_reduce(grad, group=group) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper return func(*args, **kwargs) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/comm.py", line 496, in all_reduce return cdb.all_reduce(tensor, op, group, async_op) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/deepspeed/comm/torch.py", line 159, in all_reduce return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op) File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1520, in all_reduce _check_single_tensor(tensor, "tensor") File "/home/ubuntu/anaconda3/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 463, in _check_single_tensor raise RuntimeError( RuntimeError: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor. ``` After some debugging, I find out the root cause is that my configuration of lora (in below) only add extra lora layer(part) in qkv related layers but not the embedding layer. So the whole embedding layer's parameters are freezed. ```python lora_config = LoraConfig(r=8, # copied from finetuning_lora.py lora_alpha=32, target_modules=["query_key_value"], lora_dropout=0.1, bias="none", task_type="CAUSAL_LM", inference_mode=False, ) ``` And in my implementation of pipeline based model, I declared the embeding layer as a tied-layer. So the whole thing is that there are no gradients at all for embedding layer, but embedding layer as the tied layer needs to be synced between two gpus. The value of gradient is None but is still passed to `all_reduce` operation. Current, my fix is simple and add a check if this `grad` is None. --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Heyang Qin <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Update checkout action to latest version (#5021) Latest checkout uses latest (non-deprecated) version of node (16 -> 20). More information [here](https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/): ``` Node.js 16 actions are deprecated. Please update the following actions to use Node.js 20: actions/checkout@v3. For more information see: https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/. ``` Checkout action: https://github.com/actions/checkout Node 20 requires a minimum of Ubuntu 20.04, so workflows currently using 18.04 are failing/will fail. * Add attribute check to support git-base autotp (#6688) Git-base model is an image-text model. After supporting the llama3.2 vision model, we set num_kv_heads dynamically. Git-base only includes vision_config, so we need to add an attribute check for vision_config/text_config when setting num_kv_heads. Co-authored-by: Logan Adams <[email protected]> * fix memcpy issue on backward for zero-infinity (#6670) This PR is similar to [PR#5301](https://github.com/microsoft/DeepSpeed/pull/5301), that optimizes the D2H time use pinned memory. Previously, the D2H memcpy will be the bottleneck during the final backward pass of each iteration for ZeRO-Infinity(offload), as shown in Trace-1. The new version can eliminate the bottleneck, as shown in Trace-2. _Trace-1_ <img width="480" alt="image" src="https://github.com/user-attachments/assets/891e3770-351b-4e03-8a59-b491bc44d03b"> _Trace-2_ <img width="192" alt="image" src="https://github.com/user-attachments/assets/f1cf9037-77f8-42a6-adc8-d5c6bacde0aa"> cc @tjruwase --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * Free memory in universal checkpointing tests (#6693) Tests in universal checkpointing were not freeing the engine after use when `reuse_dist_env` was set to `True`, leading to memory leaks. This PR ensure freeing the engine in the tests and enables `reuse_dist_env`. Co-authored-by: Logan Adams <[email protected]> * Explictly set device when reusing dist env (#6696) A rank of a process can change when reusing the environment. This PR explicitly sets the device when reusing the environment. * Update URL in README Pipeline Status for Huawei Ascend NPU (#6706) * Pin transformers to 4.45.2 in nv-ds-chat workflow (#6710) This commit causes breaking changes we need to fix, for now we will pin the version but we will fix shortly https://github.com/huggingface/transformers/pull/33325 * [Bug Fix] Support threads_per_head < 64 for wavefront size of 64 (#6622) When launching apply_rotary_pos_half kernel, only threads_per_head of 64 is supported for wavefront size of 64. This change adds support for threads_per_head < 64 such as 4, 8, 16. Fixes the issue introduced in https://github.com/microsoft/DeepSpeed/pull/5402 --------- Signed-off-by: Jagadish Krishnamoorthy <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Use one param coordinator for both train/inference scenarios (#6662) The parameter coordinator in ZeRO3 throws a "backward pass is invalid for module in evaluation mode" error when the training mode is unexpected, as it expects all modules to be in training mode during the backward pass. This is an unnecessarily strict restriction. This PR relaxes the restriction by using a single parameter coordinator (instead of separate ones for training and evaluation modes) and resetting the prefetch state before starting a forward pass. Use of `is_compiling` needs to be fixed after #6663 is merged. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Update yapf version (#6721) This update is needed to support eventually running on ubuntu-24.04 from GitHub, specifically because the python version is updated to 3.12 and results in the following error: `ModuleNotFoundError: No module named 'lib2to3'` since that package is deprecated. * Update flake8 version (#6722) This PR is useful for updating the flake8 checks we run, but is mostly needed to update flake8 so that it can run on newer versions of python which are included in newer ubuntu-latest versions from GitHub that we update to in #6717 * Switch what versions of python are supported (#5676) Add support for testing compilation with python 3.11/3.12. Also add the dockerfiles used to build those images. --------- Co-authored-by: Michael Wyatt <[email protected]> * Update version.txt after 0.15.4 release (#6731) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.15.4 Author - @loadams Co-authored-by: loadams <[email protected]> * Update GH hosted workflows to 24.04 (#6717) `ubuntu-latset` is moving to be 24.04, so we should test updating as well to ensure it doesn't break any of our workflows. * Add COMMITTER file (#6741) Add COMMITTER file * Update AMD apex version (#6739) * Fix Type Name Inconsistency & Typo in cpu_adam (#6732) There is a typing error & inconsistency in cpu-adam code, while not affecting functionality, impacts code readability. Specifically, the type name `ds_params_percision_t` contains a typo ('percision'), whereas the related type name `ds_state_precision_t` is spelled correctly. I think it is beneficial to fix this typo&inconsistency to improve code readability, maintainability and further development. I have tested the corrected version of cpu_adam, and it compiles and runs successfully. Compilation Log: <img width="2560" alt="image" src="https://github.com/user-attachments/assets/b7bc307d-9c9d-4ab7-8671-34e565903ca5"> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * Add Domino code (#6733) add domino code Co-authored-by: Logan Adams <[email protected]> * Add data type check for bf16 (#6742) Add data type check for bf16 to fix #6723 * add zero3 ```module_granularity_threshold ``` to zero optimization. (#6649) This PR adds Z3 coalesced fetch to zero optimization. Currently, some logic can be reused, but it's difficult to realize that as optimization choice(I only discovered these logic when trying to implement it). The benefit of this approach is reducing host overhead(reduce many hooks) and during the process of recursive fetching parameters (especially in fine-grained models, such as those with a large number of moe experts). This is particularly helpful for host-sensitive devices (such as hpu), where it achieved a 40% performance improvement in our customer workloads. FYI @delock @deepcharm --------- Co-authored-by: Ma, Guokai <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * AIO File Offsets (#6641) Adding the option for a file offset to the read/write functions of AIO & GDS ops. --------- Co-authored-by: jomayeri <deepspeed@H100-VM2.shlnn55tgwve1eacvp21ie45dg.jx.internal.cloudapp.net> Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Update path for BingBertSquad from DeepSpeedExamples (#6746) In https://github.com/microsoft/DeepSpeedExamples/pull/245, the DeepSpeedExamples directory structure was refactored, this updates the DeepSpeed examples from those changes. * Sanitize inputs to eval() (#6745) * Adding the governance doc (#6748) Drafted governance doc for the LFAI. Co-authored-by: Minjia Zhang <[email protected]> * Add no_sync context manager (#6675) Fix #1902 --------- Co-authored-by: Logan Adams <[email protected]> * Gaudi2 Nightly job for daily check (#6753) Co-authored-by: Logan Adams <[email protected]> * Disable failing python tests (#6758) * A faster and more memory-efficient implementation of `zero_to_fp32` (#6658) It is a faster and more memory-efficient implementation of `zero_to_fp32`. The previous version double the memory usage, which cause cpu OOM for very large models (e.g. llama 405B). https://github.com/microsoft/DeepSpeed/blob/b647fb2470f8f6fefe5cab0ea84a2d89696eb898/deepspeed/utils/zero_to_fp32.py#L438-L441 ## How does it work? 1. **Lazy loading**: Load checkpoint with `mmap=True`, thus the weights are mmaped rather than loading all the storages into memory. 2. **Lazy merge**: `GatheredTensor` contains the mmaped weights and tensor offset. It is a memory-efficient pseudo tensor. Only when `tensor.contiguous()` is called, it starts to load related weights to memory and merge into a single tensor. 3. **Release memory in time**: Save checkpoints shard by shard, and release the memory once a shard is saved. Throughout the process, only one shard of tensors are keeped in memory. ## How much benefit in speed and memory ? Experiments were conducted on a linux host with 1TB of memory. Here is a detailed comparision | | world size | peak memory(GB) | elapsed time(h:mm:ss) | |----------------------|------------|--------------|--------------------| | llama3-8B(old->new) | 8 | 90 -> 41 | 0:02:17 -> 0:01:10 | | llama2-13B(old->new) | 8 | 146 -> 54 | 0:02:30 -> 0:01:47 | | llama2-70B(old->new) | 16 | 789 -> 159 | 0:20:47 -> 0:20:45 | | qwen1.5-110B(old->new) | 32 | OOM -> 217 | ? -> 0:34:21 | | llama3-405B(old->new) | 192 | OOM -> 262 | ? -> 2:09:59 | You can reproduce with the following scripts ```sh # 1. install requirments apt-get install time # 2. prepare zero-3 checkpoints # 3. convert zero to fp32 checkpoints /usr/bin/time -v python zero_to_fp32.py . output_dir/ --safe_serialization ``` - **memory**: Theoretically, this PR reduces the memory cost from `2M` to `(1/n)M`, where `M` is the memory cost of the full weights, `n` is num_shards. - **speed**: The speed gain mainly comes from avoiding extra tensor copying. The benifit may be slight. ## Impl history - [v1](https://github.com/xu-song/DeepSpeed/commit/19712a1c75bfc1da4a7f3ecca6915a86af671568#diff-6a2ca3427fa608c387b7351359f98cfc1313be6e960cee86344ff246bf1b8326R441-R447) : a hf_hub compatible approach. It has been discarded due to the controversial implementation of `data_ptr().` - [v2](https://github.com/microsoft/DeepSpeed/pull/6658/files): a simple approach with `torch.empty` --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Pin transformers version to work around latest torch requirements (#6759) Latest transformers seems to break our tests that aren't on torch latest (>= 2.5). Issue opened here: https://github.com/huggingface/transformers/issues/34795. This pins our version so these tests can pass in the meantime. * make xpu ops compatible with oneapi 2025.0 (#6760) Compatibility update for xpu ops This PR introduces changes that will make xpu ops compatible with the OneAPI 2025.0 toolkit. This is an important update that will allow us to develop and ship our most demanding models on this innovative hardware. --------- Signed-off-by: baodii <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Add explicit parameters for torch.load (#6751) Successor PR to #6094: > FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. Todo: - [ ] Update values in non-test files to True where necessary. * Fix setup.py bash cmd generation to correctly extract git info (#6762) Co-authored-by: Logan Adams <[email protected]> * Use `json_schema_extra` instead of extra keyword in `Field` (#6764) > Using extra keyword arguments on `Field` is deprecated and will be removed. Use `json_schema_extra` instead. (Extra keys: 'new_param'). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.9/migration/ Co-authored-by: Logan Adams <[email protected]> * Enable torch compile on _allgather_params (#6769) * Previosuly ZerO3 was crashing when trying to compile _allgather_params * Disabling grad solves the issue * Removes unnecessary cloning (#6761) `clone_tensors_for_torch_save()` function: When the `item.device` is different from `device` input, `tensor.clone()` is not actually required because `to()` function also clones the original tensor. +) I observed memory bloat under following conditions: * Training a Whisper model w/ `transformers` framework with `ZeRO-0` and `ZeRO-1` configuration. * Memory bloating can be observed every time the model state_dict is cloned using `clone_tensors_for_torch_save()` After I removed the unnecessary `clone()`, seems like the problem is solved. Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * Fix potential memory issues when use deepspeed Z3 (#6726) I had OOM problem when doing DPO training using zero3. It needs to call module twice in one training step, and second call is with no_grad(). The problem is caused by two bugs: 1. "__n_available_params", which helps to control fetched parameters, becomes negative after release_and_reset_all() function. 2. module.ds_grads_remaining becomes negative in backward() if we call module more than once in one training step. I tried to create two patches to fix these issues. --------- Signed-off-by: Wenbin Chen <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Hongwei Chen <[email protected]> * Unpin with latest transformers fixes (#6763) Reverts #6759 Requires from transformers: https://github.com/huggingface/transformers/pull/34816 https://github.com/huggingface/transformers/pull/34800 Todo: - [x] Need to merge first PR to get support for torch 2.4 * docs: fix HF links (#6780) The current link https://huggingface.co/docs/transformers/main_classes/deepspeed is very unhelpful. It turns out in the past it had some guides: https://huggingface.co/docs/transformers/v4.27.1/main_classes/deepspeed#shared-configuration Later it's refreshed and moved to https://huggingface.co/docs/transformers/deepspeed * Fix Doc Error: ZeRO Stage 2 gradient partitioning (#6775) Fix the issue described in https://github.com/microsoft/DeepSpeed/issues/6707 * Cleanup code docs warnings (#6783) We have a number of warnings in our readthedocs sphinx/autodoc .rst files, so this cleans some of those up so we can fix real issues there. * Domino Blog (#6776) This PR is domino blog on our public side. cc @tjruwase --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Update version.txt before release (#6784) * Revert release workflow (#6785) * Update version.txt after 0.16.0 release (#6786) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.16.0 Author - @loadams Co-authored-by: loadams <[email protected]> * Domino news update on readme.md (#6815) * Fix zero checkpoint (#6792) Fix #6791 --------- Co-authored-by: Logan Adams <[email protected]> * Update python version but now we need to include setuptools on our own (#6787) TODO: - [x] determine if this means we should technically add setuptools to the requirements.txt * Adding the new feature of FPDT (#6462) [FPDT](https://arxiv.org/abs/2408.16978) can only be used with [this version](https://github.com/microsoft/Megatron-DeepSpeed/pull/441) of Megatron-DeepSpeed. --------- Co-authored-by: Jinghan Yao <[email protected]> Co-authored-by: Sam Ade Jacobs <[email protected]> Co-authored-by: Jinghan Yao <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Jinghan Yao <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]> * Pin transformers to avoid errors with latest version (#6820) * Ulyssess offload blog (#6814) Ulysses-Offload (FPDT) blog, please see corresponding tutorial page at [link](https://github.com/microsoft/DeepSpeed/pull/6813). --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * add FPDT tutorial (#6813) Tutorial page for Ulysses-Offload (FPDT), blog page to follow. --------- Co-authored-by: Jinghan Yao <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Update README.md (#6824) Fix broken tutorial link * Update README.md (#6825) Add Ulysses-offload to News page Co-authored-by: Logan Adams <[email protected]> * Pin transformers version in cpu-torch-latest due to multiprocessing error. (#6823) This is a copy of https://github.com/microsoft/DeepSpeed/pull/6820 for the cpu-torch-latest tests. This PR will revert/fix these: https://github.com/microsoft/DeepSpeed/pull/6822 * Update pre-commit version (#6821) * Update version.txt after 0.16.1 release (#6826) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.16.1 Author - @loadams Co-authored-by: loadams <[email protected]> * Pin HPU tests (#6831) HPU tests are impacted by the same issue as other tests that use transformers latest. This PR pins to a version of transformers before the fix. * Flops profiler support einops.einsum (#6755) - Added support for FlopsProfiler to include einops.einsum operation - Added _patch_miscellaneous_operations() and _reload_miscellaneous_operations() to include this operation and potentially include other miscellaneous operations in the future - Added _einops_einsum_flops_compute() that mimic already-existed _einsum_flops_compute() --------- Co-authored-by: Logan Adams <[email protected]> * Pin pytest-subtests version for accelerate tests (#6842) * Inference UTs check for trition support from accelerator (#6782) Instead of checking if installed or not check for support. Skip if not supported. Co-authored-by: Logan Adams <[email protected]> * Unpin pytest-subtests now that 0.14.1 is released (#6844) The issue we encountered was covered here: https://github.com/pytest-dev/pytest-subtests/issues/173 And is resolved with the latest changes from this PR: https://github.com/pytest-dev/pytest-subtests/issues/174, and is published in the latest version 0.14.1. * Merge LoCo with Zero++ (#6730) ### Integration of LoCo Method into ZeRO++ #### Overview This PR introduces the integration of the **LoCo** method, as outlined in [this paper](https://arxiv.org/abs/2407.04480), into the ZeRO++ framework of DeepSpeed. The key enhancement involves applying error feedback compensation to 4-bit gradients before communication. This approach ***improves pre-training loss outcomes without additional time overhead***, though it requires extra GPU memory. The extent of this memory increase depends on model size and training configuration. #### Experimental Results We conducted pre-training experiments using the Llama2 architecture, adjusting the number of layers and hidden size. The experiments included: - **A smaller-scale model with 0.8B parameters trained on 30B tokens**. - **A larger-scale model with 8B parameters trained on 5B tokens**. The training data was sampled from **Redpajama-V2**. <p align="center"> <img src="https://github.com/user-attachments/assets/e7db9487-728c-4a17-9806-c15afa12f62e" width="49%" /> <img src="https://github.com/user-attachments/assets/3efec895-b71d-43ab-b5ce-65468ba8b9f1" width="49%" /> </p> **Findings**: - **Smaller Models (0.8B parameters)**: Significant gains were observed when applying the LoCo method. - **Larger Models (8B parameters)**: The gains were present but less pronounced. This could be due to: 1. Relatively smaller data volume. 2. Lower pre-training loss for larger models, making significant improvements harder to achieve. However, even a smaller pre-training loss gap in larger models can translate to meaningful gains in downstream tasks. #### Example Script For reference, the [run.sh](https://github.com/user-attachments/files/17679552/zeroplus-7b3.zip) script used for the 8B parameter, 5B tokens experiment is attached. The experiment was conducted using the **DeepSpeed-Megatron** platform. #### Acknowledgments Special thanks to cc @GuanhuaWang for ongoing communication and guidance throughout this work. --- We appreciate your consideration of this PR and welcome any feedback or questions! --------- Co-authored-by: ChuanxinTang <[email protected]> Co-authored-by: root <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Hongwei Chen <[email protected]> * Fix type error in `ZeROOrderedDict` (#6794) As @keskival pointed in https://github.com/microsoft/DeepSpeed/commit/3d5cf739ead7c78f518a518ccaa15a323bd5c8da#r149582004, I've confirmed there's a type error, which this PR fixes. I didn't run into this because our internal version still use `*r2`. Co-authored-by: Tero Keski-Valkama <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Fix uneven head sequence parallelism bug (#6774) (#6797) Here `gather_idx < 2` represents `is_first_all2all`. During the first all2all, `uneven_head_all2all` will be called if either `num_heads % seq_world_size != 0` or `get_num_kv_heads() is None`. During the second all2all, it'll return return `uneven_head_all2all` if and only if `get_num_kv_heads() is None` which is always set during the first uneven all2all. This means that there will no longer be issue where `uneven_head_all2all ` is returned for the second all2all because of `num_heads % seq_world_size != 0`. Fixes: #6774 --------- Co-authored-by: Logan Adams <[email protected]> * Fix nv-torch-nightly test by pinning transformers (#6849) * Remove broken links to non-active site (#6854) The site referenced in various places on the README is no longer active: https://deepspeed4science.ai  Co-authored-by: Logan Adams <[email protected]> * Avoid poisoning process with CUDA calls as soon as importing (#6810) Call `torch.cuda.device_count() > 0` before `torch.cuda.is_available()`, to give priority to nvml based availability, so that we can try not to poison process with CUDA calls as soon as we execute `import deepspeed`. https://github.com/pytorch/pytorch/blob/v2.5.1/torch/cuda/__init__.py#L120-L124 There are 2 reasons to make this change: Firstly, if we accidentally import deepspeed, since the CUDA runtime initializes when the first CUDA API call is made and caches the device list, changing the CUDA_VISIBLE_DEVICES within the same process after initialization won't have any effect on the visible devices. The specific case: https://github.com/OpenRLHF/OpenRLHF/pull/524#issuecomment-2501505023 A demo for reproduction before the fix is applied: ```python import torch import os os.environ["CUDA_VISIBLE_DEVICES"] = "" import deepspeed os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" torch.cuda.set_device('cuda:0') ``` Secondly, https://pytorch.org/docs/stable/notes/cuda.html When assessing the availability of CUDA in a given environment (is_available()), PyTorch’s default behavior is to call the CUDA Runtime API method cudaGetDeviceCount. Because this call in turn initializes the CUDA Driver API (via cuInit) if it is not already initialized, subsequent forks of a process that has run is_available() will fail with a CUDA initialization error. Signed-off-by: Hollow Man <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Fix xpu tests workflow failure by changing pip index url (#6864) Update xpu-max1100.yml and xpu-compile.yml * Domino updates (#6861) Updating our website for Domino --------- Co-authored-by: Logan Adams <[email protected]> * add domino navigation (#6866) add domino item into navigation list * Update TSC (#6867) * Remove warnings from autodoc and sphinx (#6788) Co-authored-by: Olatunji Ruwase <[email protected]> * Update real_accelerator.py (#6845) ### Comment out or delete `accelerate_name="cpu"` when `xpu` is not detected. When `xpu `is not detected it just pass at lines from 68 to 74 if `DS_ACCELERATOR` is set. However, `cpu` is assigned to `accelerate_name` if it cannot import `intel_extension_for_pytorch` or find` xpu`, namely, at line from 125 to 133 when`DS_ACCELERATOR` is not set. I found this problem yesterday and spent whole afternoon figuring it out. I got `intel_extension_for_pytorch `installed with other package which I do not use actually and have no idea about this. Then I found that it `cpu` is assigned to accelerate_name directly if it cannot find `xpu` and it affects `cuda` detection. In fact, `cpu` will be assigned finally if `cuda` is even not detected at line from 170 to 177. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Fix assertion for offloading states (#6855) This PR fixes the assertions in `offload_states` method mentioned in #6833. Co-authored-by: Logan Adams <[email protected]> * Remove pin from transformers version and fix Processing/Threading issues in tests (#6822) Changes from https://github.com/huggingface/transformers/pull/34966 caused the `nv-torch-latest-v100` tests to fail with the following error: ``` File "/tmp/azureml/cr/j/e4bfd57a509846d6bbc4914639ad248d/exe/wd/actions-runner/_work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3941, in from_pretrained raise EnvironmentError( OSError: Can't load the model for 'hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2' is the correct path to a directory containing a file named pytorch_model.bin, tf_model.h5, model.ckpt or flax_model.msgpack. ``` Sample failure here: https://github.com/microsoft/DeepSpeed/actions/runs/12169422174/job/33942348835?pr=6794#step:8:3506 This was resolved on the Transformers side here: https://github.com/huggingface/transformers/pull/35236 * Add MLP/lm_head tp grain size setting. (#6828) This PR aims to add MLP/lm_head tp size granularity setting to deepspeed.init_inference() API. It will be more flexible to set the MLP/lm_head sharding grain size. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size. We aim to be able to set the MLP/lm_head tp grain size flexibly. This is a preliminary solution. If there is a better solution, we can discuss it together. Thanks~ --------- Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> * Fix --enable_each_rank_log when used with PDSH multi-node runner (#6863) This PR addresses fixes https://github.com/microsoft/DeepSpeed/issues/6859 by threading this argument into the deepspeed launcher command build by PDSHRunner. --------- Co-authored-by: Logan Adams <[email protected]> * Update transformers ops unit tests to use `requried_torch_version` (#6884) * Don't error out when cpu accelerator doesn't have torch (as default for whl building) (#6886) This fixes a bug introduced in #6845, which breaks the `no-torch` workflow that we require in order to do releases where we do not require torch to be in the environment when building an sdist. This adds the same logic to the cpuaccelerator that the cudaaccelerator had where we don't require torch to be installed to build the whl. * Add arctic model support by adding w2 to all_reduce (#6856) As title says. Default behavior of arctic model produces shape issues with AutoTP due to the MLP layer performing `w2 * act(w1*w3)`. However, method provided to fix Mixtral-7x8b in #5257 does not work since the MLP for Arctic is also used within a ModuleList for the MoE. This results in MLP weights hiding behind individual experts as layers `#.w#`, which is not caught by the fix in #5257. This adds the check directly within replace, where it can check for actual layer names for the `w2` key in the model to patch with `all_reduce`. --------- Signed-off-by: Daniel Huang <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Update code owners (#6890) Co-authored-by: Logan Adams <[email protected]> * Update version.txt after 0.16.2 release (#6893) **Auto-generated PR to update version.txt after a DeepSpeed release** Released version - 0.16.2 Author - @loadams Co-authored-by: loadams <[email protected]> * Allow to compile collective for PT>2.3 (#6899) Allow to compile collective for PT>2.3 commit re-uploaded due to github CI issue originally uploaded by @nelyahu * Zero2: avoid graph breaks in torch.compile by using param_idx (#6803) inside reduce_independent_p_g_buckets_and_remove_grads and in reduce_ipg_grads which are being executed during the BWD hook in zero2, the model param is being stored inside params_in_ipg_bucket. torch.compile has hard time tracing parameters. By using the param's static index inside the group the same logic can be maintain with less complexity. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> * hpu_accelerator: use torch.use_deterministic_algorithms (#6897) formal API instead of hpu.setDeterministic * Fix error caused by all_reduce call in domino (#6880) Fix #6851 Initialize communication backend to fix error caused by all_reduce call in the Domino transformer layer. Verified correctness in local test. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> * Update Gaudi2 jobs to latest 1.19 build (#6905) Co-authored-by: Logan Adams <[email protected]> * Change compile for pipeline module torch.compile (#6478) We have encountered and issue with torch.compile and the pipeline module. modifying a member of the module (micro_offset) during the forward function will cause torch compile to restart the analysis and treat the module as dynamic. In order to bypass this issue without significantly changing the way the pipeline module works we propose to compile only the layers in the pipeline module instead of the forward function of pipeline module. this will bypass the issue and should still give most of the benefit of torch compiling the pipeline module while avoiding the issue. --------- Co-authored-by: Logan Adams <[email protected]> * Stage3: Use new torch grad accumulation hooks API (#6773) * This commit addresses a Deepspeed issue [#6718](https://github.com/microsoft/DeepSpeed/issues/6718) * The existing code has been using the grad_acc node hook to reduce params grads. The constructs such as `param.data = replicated_tensor.data` used in `allgather_pa…
This PR adds the following APIs to offload model, optimizer, and engine states.
Here is the typical usage.
You can selectively offload states to balance the offloading overhead and memory saving.
Performance (4.3B parameters / 4x A100)
python output_table.py
TODO: