Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Distributed Reshape Implementation (#18068)
This DistributedReshape aims at supporting all sharding patterns encountered in llama 2. All patterns found are tested in `TestDistributedReshape` in `onnxruntime_test_distributed.py`. This PR implements algorithms to compute the categories below. - All inputs and outputs are replica, so it's computed like a normal Reshape. - Two-axis fusion (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch, seq, hidden] -> [batch x seq, hidden]`. - Two-axis decomposition (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch x seq, hidden] -> [batch, seq, hidden]`. Review guideline: - Ignore the changes in sharding_spec.h and sharding_spec.cc since they come from another PR #18025. - First, read onnxruntime_test_distributed.py to get familiar with the input/output of DistributedReshape. - Second, check the new APIs in reshape.h/reshape.cc to expose CUDA Reshape kernel to DistributedReshape. - For DistributedReshape, check its `ComputeInternal` for the 3 categories mentioned above.
- Loading branch information