Distributed Reshape Implementation #18068

Merged
merged 2 commits into main from wechi/d-reshape on Oct 27, 2023
Conversation

wschin
Contributor

@wschin wschin commented Oct 23, 2023

DistributedReshape aims to support all sharding patterns encountered in Llama 2. All patterns found are tested in TestDistributedReshape in onnxruntime_test_distributed.py. This PR implements algorithms to handle the categories below.

  • All inputs and outputs are replicas (fully replicated), so it's computed like a normal Reshape.
  • Two-axis fusion (when any of the inputs or outputs is sharded). This category covers, e.g., [batch, seq, hidden] -> [batch x seq, hidden].
  • Two-axis decomposition (when any of the inputs or outputs is sharded). This category covers, e.g., [batch x seq, hidden] -> [batch, seq, hidden]. A concrete sketch of both sharded cases follows this list.
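
To make the two sharded categories concrete, here is a small NumPy sketch (illustrative only; the shapes, the 2-device mesh, and the choice of sharding the batch axis are assumptions, not values from the PR) showing that, when the sharded axis stays the outer axis of the fused/decomposed pair, each device can run the reshape locally:

```python
import numpy as np

# Illustrative setup (not from the PR): a 1-D device mesh with 2 devices,
# and an input of shape [batch, seq, hidden] sharded on the batch axis.
num_devices = 2
batch, seq, hidden = 4, 3, 5
x = np.arange(batch * seq * hidden).reshape(batch, seq, hidden)
shards = np.split(x, num_devices, axis=0)  # device d owns shards[d]

# Two-axis fusion: [batch, seq, hidden] -> [batch x seq, hidden].
# Because the sharded axis (batch) is the outer axis of the fused pair,
# each device can reshape its local shard independently, with no communication.
local_fused = [s.reshape(-1, hidden) for s in shards]
global_fused = x.reshape(batch * seq, hidden)
assert np.array_equal(np.concatenate(local_fused, axis=0), global_fused)

# Two-axis decomposition is the inverse: [batch x seq, hidden] -> [batch, seq, hidden],
# with the output sharded on the recovered batch axis.
local_decomposed = [s.reshape(-1, seq, hidden) for s in local_fused]
assert np.array_equal(np.concatenate(local_decomposed, axis=0), x)
```

If the inner axis of the pair (seq here) were the sharded one instead, the per-device rows would interleave in the fused layout, so a purely local reshape would not reproduce the global result; reasoning about those sharding specs is what the distributed kernel is for.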

Review guideline:

  • Ignore the changes in sharding_spec.h and sharding_spec.cc since they come from another PR, Support per-tensor device mesh at op level #18025.
  • First, read onnxruntime_test_distributed.py to get familiar with the inputs/outputs of DistributedReshape.
  • Second, check the new APIs in reshape.h/reshape.cc that expose the CUDA Reshape kernel to DistributedReshape.
  • Third, for DistributedReshape, check how its ComputeInternal handles the 3 categories mentioned above (a rough shape-pattern sketch follows this list).
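
As a reading aid only, here is a hypothetical Python helper (not code from this PR; the real kernel also consults the per-tensor sharding specs, and the replica category is decided by those specs rather than by shapes) that recognizes the two shape patterns named above:

```python
def classify_shape_change(src_shape, dst_shape):
    """Hypothetical illustration: detect an adjacent two-axis fusion or decomposition."""
    src, dst = list(src_shape), list(dst_shape)
    # Two-axis fusion: [..., a, b, ...] -> [..., a*b, ...]
    if len(src) == len(dst) + 1:
        for i in range(len(dst)):
            if src[:i] + [src[i] * src[i + 1]] + src[i + 2:] == dst:
                return f"two-axis fusion of axes ({i}, {i + 1})"
    # Two-axis decomposition: [..., a*b, ...] -> [..., a, b, ...]
    if len(dst) == len(src) + 1:
        for i in range(len(src)):
            if dst[:i] + [dst[i] * dst[i + 1]] + dst[i + 2:] == src:
                return f"two-axis decomposition of axis {i}"
    return "no adjacent two-axis fusion/decomposition"


print(classify_shape_change([2, 6, 4], [12, 4]))  # two-axis fusion of axes (0, 1)
print(classify_shape_change([12, 4], [2, 6, 4]))  # two-axis decomposition of axis 0
```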

@wschin wschin force-pushed the wechi/d-reshape branch 2 times, most recently from d3a5ba2 to b0a7569 on October 25, 2023 23:07
@wschin wschin marked this pull request as ready for review October 26, 2023 00:11
@wschin wschin requested a review from souptc October 26, 2023 00:11
wschin added a commit that referenced this pull request Oct 26, 2023
Since Reshape may change the device mesh from, e.g., [0, 1] to [0, 1, 0, 1],
we can't assume a single device mesh per op. At each operator, we replace the
single operator-level device mesh attributes
- `device_mesh_shapes`
- `device_mesh_elements`

with per-tensor device meshes
- `input_device_mesh_shapes` (input_device_mesh_shapes[i] is the device
mesh's shape for the i-th input, e.g., "[3]" for a 1-D mesh with 3
devices)
- `input_device_mesh_elements` (input_device_mesh_elements[i] is the
flattened device mesh elements for the i-th input, e.g., "[0, 1, 2]" if
you have 3 devices in that mesh)
- `output_device_mesh_shapes`
- `output_device_mesh_elements`

Check out the change in onnxruntime_test_distributed.py for examples.
It's also heavily used in #18068's `onnxruntime_test_distributed.py`
change.
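
For illustration only, a hedged sketch of how these per-tensor attributes could be attached to a node with onnx.helper; the attribute names follow the commit message above, while the domain, tensor names, mesh values, and the omission of any other attributes a real node may require (e.g., shard specs) are assumptions of this example:

```python
from onnx import helper

# Hypothetical node construction; values are illustrative, not from the PR.
node = helper.make_node(
    "DistributedReshape",
    inputs=["data", "shape"],
    outputs=["reshaped"],
    domain="com.microsoft",
    # Entry i describes the device mesh of the i-th input tensor.
    input_device_mesh_shapes=["[2]", "[2]"],          # a 1-D mesh with 2 devices per input
    input_device_mesh_elements=["[0, 1]", "[0, 1]"],  # flattened device ids of each mesh
    # The output mesh may differ, e.g., [0, 1] becoming [0, 1, 0, 1] after Reshape.
    output_device_mesh_shapes=["[4]"],
    output_device_mesh_elements=["[0, 1, 0, 1]"],
)
print(node)
```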
@souptc souptc merged commit 9c32310 into main Oct 27, 2023
@souptc souptc deleted the wechi/d-reshape branch October 27, 2023 05:33
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024