Skip to content

Commit

Permalink
Distributed Reshape Implementation (#18068)
Browse files Browse the repository at this point in the history
This DistributedReshape aims at supporting all sharding patterns
encountered in llama 2. All patterns found are tested in
`TestDistributedReshape` in `onnxruntime_test_distributed.py`. This PR
implements algorithms to compute the categories below.
- All inputs and outputs are replica, so it's computed like a normal
Reshape.
- Two-axis fusion (if any of the inputs and outputs are sharded). This
category convers, e.g., `[batch, seq, hidden] -> [batch x seq, hidden]`.
- Two-axis decomposition (if any of the inputs and outputs are sharded).
This category convers, e.g., `[batch x seq, hidden] -> [batch, seq,
hidden]`.

Review guideline:
- Ignore the changes in sharding_spec.h and sharding_spec.cc since they
come from another PR #18025.
- First, read onnxruntime_test_distributed.py to get familiar with the
input/output of DistributedReshape.
- Second, check the new APIs in reshape.h/reshape.cc to expose CUDA
Reshape kernel to DistributedReshape.
- For DistributedReshape, check its `ComputeInternal` for the 3
categories mentioned above.
  • Loading branch information
wschin authored Oct 27, 2023
1 parent b79ea74 commit 9c32310
Show file tree
Hide file tree
Showing 14 changed files with 1,870 additions and 40 deletions.
3 changes: 2 additions & 1 deletion cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
Expand Down Expand Up @@ -246,4 +247,4 @@
install(TARGETS onnxruntime_providers_cuda
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
2 changes: 2 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ if (NOT onnxruntime_USE_NCCL)
list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
endif()

set(provider_excluded_files
Expand Down
Loading

0 comments on commit 9c32310

Please sign in to comment.