-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Distributed Reshape Implementation #18068
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
wschin
force-pushed
the
wechi/d-reshape
branch
from
October 25, 2023 03:01
f2e05c9
to
409210e
Compare
wschin
force-pushed
the
wechi/d-reshape
branch
2 times, most recently
from
October 25, 2023 23:07
d3a5ba2
to
b0a7569
Compare
wschin
force-pushed
the
wechi/d-reshape
branch
from
October 26, 2023 17:27
b0a7569
to
47b3312
Compare
wschin
added a commit
that referenced
this pull request
Oct 26, 2023
Since Reshape may change device mesh from, e.g., [0, 1] to [0, 1, 0, 1], we can't assume same device mesh per op. At each operator, we replace a single operator-level device mesh - `device_mesh_shapes` - `device_mesh_elements` with per-tensor device meshes - `input_device_mesh_shapes` (input_device_mesh_shapes[i] is the device mesh's shape for the i-th input, e.g., "[3]" for 1-D mesh with 3 devices) - `input_device_mesh_elements` (input_device_mesh_elements[i] is the flattened device mesh elements for the i-th input; e.g., "[0, 1, 2]" if you have 3 devices in that mesh) - `output_device_mesh_shapes` - `output_device_mesh_elements` Check out the change in onnxruntime_test_distributed.py for examples. It's also heavily used in #18068's `onnxruntime_test_distributed.py` change.
wschin
force-pushed
the
wechi/d-reshape
branch
from
October 26, 2023 21:52
47b3312
to
baf8f35
Compare
souptc
approved these changes
Oct 27, 2023
kleiti
pushed a commit
to kleiti/onnxruntime
that referenced
this pull request
Mar 22, 2024
Since Reshape may change device mesh from, e.g., [0, 1] to [0, 1, 0, 1], we can't assume same device mesh per op. At each operator, we replace a single operator-level device mesh - `device_mesh_shapes` - `device_mesh_elements` with per-tensor device meshes - `input_device_mesh_shapes` (input_device_mesh_shapes[i] is the device mesh's shape for the i-th input, e.g., "[3]" for 1-D mesh with 3 devices) - `input_device_mesh_elements` (input_device_mesh_elements[i] is the flattened device mesh elements for the i-th input; e.g., "[0, 1, 2]" if you have 3 devices in that mesh) - `output_device_mesh_shapes` - `output_device_mesh_elements` Check out the change in onnxruntime_test_distributed.py for examples. It's also heavily used in microsoft#18068's `onnxruntime_test_distributed.py` change.
kleiti
pushed a commit
to kleiti/onnxruntime
that referenced
this pull request
Mar 22, 2024
This DistributedReshape aims at supporting all sharding patterns encountered in llama 2. All patterns found are tested in `TestDistributedReshape` in `onnxruntime_test_distributed.py`. This PR implements algorithms to compute the categories below. - All inputs and outputs are replica, so it's computed like a normal Reshape. - Two-axis fusion (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch, seq, hidden] -> [batch x seq, hidden]`. - Two-axis decomposition (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch x seq, hidden] -> [batch, seq, hidden]`. Review guideline: - Ignore the changes in sharding_spec.h and sharding_spec.cc since they come from another PR microsoft#18025. - First, read onnxruntime_test_distributed.py to get familiar with the input/output of DistributedReshape. - Second, check the new APIs in reshape.h/reshape.cc to expose CUDA Reshape kernel to DistributedReshape. - For DistributedReshape, check its `ComputeInternal` for the 3 categories mentioned above.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This DistributedReshape aims at supporting all sharding patterns encountered in llama 2. All patterns found are tested in
TestDistributedReshape
inonnxruntime_test_distributed.py
. This PR implements algorithms to compute the categories below.[batch, seq, hidden] -> [batch x seq, hidden]
.[batch x seq, hidden] -> [batch, seq, hidden]
.Review guideline:
ComputeInternal
for the 3 categories mentioned above.