Support per-tensor device mesh at op level #18025

Merged: wschin merged 1 commit into main from wechi/per-tensor-mesh on Oct 26, 2023

Conversation

@wschin (Contributor) commented Oct 19, 2023

Since Reshape may change the device mesh from, e.g., [0, 1] to [0, 1, 0, 1], we can't assume the same device mesh per op. At each operator, we replace the single operator-level device mesh attributes

  • device_mesh_shapes
  • device_mesh_elements

with per-tensor device mesh attributes

  • input_device_mesh_shapes (input_device_mesh_shapes[i] is the device mesh's shape for the i-th input, e.g., "[3]" for a 1-D mesh with 3 devices)
  • input_device_mesh_elements (input_device_mesh_elements[i] is the flattened device mesh elements for the i-th input, e.g., "[0, 1, 2]" if that mesh has 3 devices)
  • output_device_mesh_shapes
  • output_device_mesh_elements

Check out the change in onnxruntime_test_distributed.py for examples. These attributes are also heavily used in #18068's onnxruntime_test_distributed.py change.
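
For illustration, here is a minimal sketch of how these per-tensor attributes could be attached to a distributed op node with `onnx.helper.make_node`. The op name `DistributedReshape`, the `com.microsoft` domain, and the concrete attribute values are assumptions for this example, not the exact code from `onnxruntime_test_distributed.py`:

```python
from onnx import helper

# A hypothetical distributed Reshape node; op name and domain are assumed here.
node = helper.make_node(
    "DistributedReshape",
    inputs=["data", "shape"],
    outputs=["reshaped"],
    domain="com.microsoft",
    # One entry per input tensor: mesh shape and flattened mesh elements.
    input_device_mesh_shapes=["[2]", "[2]"],
    input_device_mesh_elements=["[0, 1]", "[0, 1]"],
    # One entry per output tensor; Reshape may change the mesh,
    # e.g., from [0, 1] to [0, 1, 0, 1].
    output_device_mesh_shapes=["[4]"],
    output_device_mesh_elements=["[0, 1, 0, 1]"],
)
```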

@wschin wschin requested a review from souptc October 19, 2023 02:43
@wschin wschin force-pushed the wechi/per-tensor-mesh branch 2 times, most recently from db85449 to 10bb2fd on October 19, 2023 18:23
@wschin wschin closed this Oct 24, 2023
@wschin wschin reopened this Oct 24, 2023
@wschin wschin force-pushed the wechi/per-tensor-mesh branch from 10bb2fd to 1641874 on October 24, 2023 17:49
@wschin wschin closed this Oct 25, 2023
@wschin wschin reopened this Oct 25, 2023
@wschin wschin force-pushed the wechi/per-tensor-mesh branch from 1641874 to d358409 on October 25, 2023 03:02
@wschin wschin changed the title Support per-tensor device mesh at op level. Support per-tensor device mesh at op level Oct 25, 2023
@wschin wschin force-pushed the wechi/per-tensor-mesh branch 2 times, most recently from ab6424e to 85e3087 on October 25, 2023 20:58
Since Reshape may change the device mesh from, e.g., [0, 1] to [0, 1, 0, 1], we can't assume the same device mesh per op.

Lint

resolve style warnings
@wschin wschin force-pushed the wechi/per-tensor-mesh branch from 85e3087 to 40f985e on October 26, 2023 17:23
@wschin wschin merged commit a514a68 into main Oct 26, 2023
88 of 90 checks passed
@wschin wschin deleted the wechi/per-tensor-mesh branch October 26, 2023 21:47
souptc pushed a commit that referenced this pull request Oct 27, 2023
This DistributedReshape aims at supporting all sharding patterns encountered in llama 2. All patterns found are tested in `TestDistributedReshape` in `onnxruntime_test_distributed.py`. This PR implements algorithms to compute the categories below (a small NumPy sketch of the fusion case follows the review guideline).
- All inputs and outputs are replicated, so it's computed like a normal Reshape.
- Two-axis fusion (if any of the inputs or outputs are sharded). This category covers, e.g., `[batch, seq, hidden] -> [batch x seq, hidden]`.
- Two-axis decomposition (if any of the inputs or outputs are sharded). This category covers, e.g., `[batch x seq, hidden] -> [batch, seq, hidden]`.

Review guideline:
- Ignore the changes in sharding_spec.h and sharding_spec.cc since they come from another PR #18025.
- First, read onnxruntime_test_distributed.py to get familiar with the input/output of DistributedReshape.
- Second, check the new APIs in reshape.h/reshape.cc that expose the CUDA Reshape kernel to DistributedReshape.
- For DistributedReshape, check its `ComputeInternal` for the 3 categories mentioned above.
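
As a rough illustration of the two-axis fusion case (a NumPy sketch only, with made-up sizes, not the CUDA kernel): when `[batch, seq, hidden]` is sharded along `batch` over a 1-D mesh `[0, 1]`, each device can reshape its local shard independently, and the fused output stays sharded along its first axis:

```python
import numpy as np

# Hypothetical global tensor [batch, seq, hidden] = [4, 3, 8].
batch, seq, hidden = 4, 3, 8
x = np.arange(batch * seq * hidden, dtype=np.float32).reshape(batch, seq, hidden)

# Shard along the batch axis over a 1-D device mesh [0, 1] (two devices).
shards = np.split(x, 2, axis=0)                 # each shard is [2, 3, 8]

# Two-axis fusion computed locally on each device: [b, s, h] -> [b * s, h].
local_outputs = [shard.reshape(-1, hidden) for shard in shards]

# Concatenating the per-device results along axis 0 reproduces the global
# reshape, so no communication is needed and the output remains sharded on
# its first (fused) axis over the same mesh.
assert np.array_equal(np.concatenate(local_outputs, axis=0),
                      x.reshape(batch * seq, hidden))
```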
kleiti pushed 2 commits to kleiti/onnxruntime that referenced this pull request Mar 22, 2024