From 4fa64459d03a17839ec49d1081e9c7e15e0c7f52 Mon Sep 17 00:00:00 2001 From: liangan1 Date: Thu, 7 Mar 2024 17:08:23 +0800 Subject: [PATCH] IPEX Tensor Parallel (#2435) --- .gitmodules | 4 + cmake/Modules/FindoneCCL.cmake | 36 ++ cmake/cpu/BuildFlags.cmake | 3 +- cmake/cpu/Options.cmake | 10 + csrc/cpu/CMakeLists.txt | 24 +- .../aten/CollectiveCommunicationPrimitive.cpp | 37 ++ .../aten/CollectiveCommunicationPrimitive.h | 30 ++ csrc/cpu/aten/ShmAllReduceAdd.cpp | 32 ++ csrc/cpu/aten/ShmAllReduceAdd.h | 35 ++ .../CollectiveCommunicationPrimitiveKrnl.cpp | 38 ++ csrc/cpu/aten/kernels/SHMAllreduceAddKrnl.cpp | 243 ++++++++++ csrc/cpu/comm/CMakeLists.txt | 5 + csrc/cpu/comm/comm.cpp | 18 + csrc/cpu/comm/comm.h | 9 + csrc/cpu/comm/messager.h | 255 +++++++++++ csrc/cpu/comm/shm_reduction.h | 190 ++++++++ csrc/cpu/jit/passes/graph_rewrite.cpp | 30 +- intel_extension_for_pytorch/cpu/__init__.py | 1 + .../cpu/comm/__init__.py | 8 + .../csrc/cpu/Module.cpp | 6 + .../transformers/__init__.py | 9 + .../transformers/optimize.py | 70 +++ .../transformers/tensor_parallel.py | 425 ++++++++++++++++++ tests/cpu/test_ccl_primitive.py | 55 +++ tests/cpu/test_ipex_tensor_parallel.py | 146 ++++++ third_party/oneCCL | 1 + 26 files changed, 1717 insertions(+), 3 deletions(-) create mode 100644 cmake/Modules/FindoneCCL.cmake create mode 100644 csrc/cpu/aten/CollectiveCommunicationPrimitive.cpp create mode 100644 csrc/cpu/aten/CollectiveCommunicationPrimitive.h create mode 100644 csrc/cpu/aten/ShmAllReduceAdd.cpp create mode 100644 csrc/cpu/aten/ShmAllReduceAdd.h create mode 100644 csrc/cpu/aten/kernels/CollectiveCommunicationPrimitiveKrnl.cpp create mode 100644 csrc/cpu/aten/kernels/SHMAllreduceAddKrnl.cpp create mode 100644 csrc/cpu/comm/CMakeLists.txt create mode 100644 csrc/cpu/comm/comm.cpp create mode 100644 csrc/cpu/comm/comm.h create mode 100644 csrc/cpu/comm/messager.h create mode 100644 csrc/cpu/comm/shm_reduction.h create mode 100644 intel_extension_for_pytorch/cpu/comm/__init__.py create mode 100644 intel_extension_for_pytorch/transformers/tensor_parallel.py create mode 100644 tests/cpu/test_ccl_primitive.py create mode 100644 tests/cpu/test_ipex_tensor_parallel.py create mode 160000 third_party/oneCCL diff --git a/.gitmodules b/.gitmodules index 9dc1e0b09..bd03efbde 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,6 +7,10 @@ [submodule "third_party/libxsmm"] path = third_party/libxsmm url = https://github.com/libxsmm/libxsmm.git +[submodule "third_party/oneCCL"] + path = third_party/oneCCL + url = https://github.com/oneapi-src/oneCCL [submodule "third_party/sleef"] path = third_party/sleef url = https://github.com/shibatch/sleef.git + diff --git a/cmake/Modules/FindoneCCL.cmake b/cmake/Modules/FindoneCCL.cmake new file mode 100644 index 000000000..eeeccc780 --- /dev/null +++ b/cmake/Modules/FindoneCCL.cmake @@ -0,0 +1,36 @@ +# - Try to find oneCCL +# +# The following are set after configuration is done: +# ONECCL_FOUND : set to true if oneCCL is found. +# ONECCL_INCLUDE_DIRS : path to oneCCL include dir. 
+# ONECCL_LIBRARIES : list of libraries for oneCCL +# +# and the following imported targets: +# +# oneCCL + +IF (NOT ONECCL_FOUND) +SET(ONECCL_FOUND OFF) +SET(ONECCL_LIBRARIES) +SET(ONECCL_INCLUDE_DIRS) + +SET(ONECCL_ROOT "${PROJECT_SOURCE_DIR}/third_party/oneCCL") + +IF(BUILD_NO_ONECCL_PACKAGE) + ADD_SUBDIRECTORY(${ONECCL_ROOT} oneCCL EXCLUDE_FROM_ALL) +ELSE() + ADD_SUBDIRECTORY(${ONECCL_ROOT} build) +ENDIF() + +IF(NOT TARGET ccl) + MESSAGE(FATAL_ERROR "Failed to find oneCCL target") +ENDIF() +add_library(oneCCL ALIAS ccl) + +GET_TARGET_PROPERTY(INCLUDE_DIRS oneCCL INCLUDE_DIRECTORIES) +SET(ONECCL_INCLUDE_DIRS ${INCLUDE_DIRS}) +SET(ONECCL_LIBRARIES oneCCL) + +find_package_handle_standard_args(oneCCL FOUND_VAR ONECCL_FOUND REQUIRED_VARS ONECCL_LIBRARIES ONECCL_INCLUDE_DIRS) + +ENDIF(NOT ONECCL_FOUND) \ No newline at end of file diff --git a/cmake/cpu/BuildFlags.cmake b/cmake/cpu/BuildFlags.cmake index b129d00ce..9b4f336e1 100644 --- a/cmake/cpu/BuildFlags.cmake +++ b/cmake/cpu/BuildFlags.cmake @@ -13,7 +13,8 @@ if(env_cxx_standard GREATER -1) endif() set(CMAKE_CXX_STANDARD 17) set(CMAKE_C_STANDARD 11) -set(CMAKE_CXX_EXTENSIONS OFF) +#oneCCL build only support the gnu standard +set(CMAKE_CXX_EXTENSIONS ON) if(MSVC) set(CMAKE_COMPILE_WARNING_AS_ERROR OFF) diff --git a/cmake/cpu/Options.cmake b/cmake/cpu/Options.cmake index bfb41418b..39d556e80 100644 --- a/cmake/cpu/Options.cmake +++ b/cmake/cpu/Options.cmake @@ -13,6 +13,14 @@ if(WIN32) set(USE_LIBXSMM ON) endif() +option(USE_CCL "Enable oneCCL in IPEX" ON) +option(USE_SHM "Enable shared memory communication in IPEX" ON) +if(WIN32) + set(USE_SHM OFF) +endif() +#set USE_SHM to OFF if USE_CCL is OFF + + cmake_dependent_option(BUILD_STATIC_ONEMKL "Static link with oneMKL" OFF "BUILD_WITH_XPU" ON) function (print_cpu_config_summary) @@ -49,6 +57,8 @@ function (print_cpu_config_summary) message(STATUS " IPEX_DISP_OP : ${IPEX_DISP_OP}") message(STATUS " BUILD_XSMM_VIA_CMAKE : ${BUILD_LIBXSMM_VIA_CMAKE}") message(STATUS " USE_LIBXSMM : ${USE_LIBXSMM}") + message(STATUS " USE_CCL : ${USE_CCL}") + message(STATUS " USE_SHM : ${USE_SHM}") message(STATUS "") message(STATUS "********************************") endfunction() diff --git a/csrc/cpu/CMakeLists.txt b/csrc/cpu/CMakeLists.txt index 187f1a248..8a9eb7624 100644 --- a/csrc/cpu/CMakeLists.txt +++ b/csrc/cpu/CMakeLists.txt @@ -11,6 +11,16 @@ set(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE) set(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE) #find_package(TorchCCL REQUIRED) +# Find OneCCL Lib +set(DEPENDS_LIB) +if(USE_CCL) + include(${IPEX_ROOT_DIR}/cmake/Modules/FindoneCCL.cmake) + # Find OneCCL Lib + link_directories(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/oneCCL/deps/mpi/lib) + find_package(oneCCL REQUIRED) + list(APPEND DEPENDS_LIB oneCCL) + list(APPEND DEPENDS_LIB mpi) +endif() # TODO: Once llga is merged into oneDNN, use oneDNN directly as the third_party of IPEX # use the oneDNN in llga temporarily: third_party/llga/third_party/oneDNN @@ -34,6 +44,14 @@ if(USE_LIBXSMM) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_LIBXSMM") endif(USE_LIBXSMM) +if(USE_CCL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_CCL") +endif(USE_CCL) + +if(USE_SHM) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SHM") +endif(USE_SHM) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_IPEX_MAIN_LIB") # ---[ Main build @@ -73,6 +91,7 @@ add_subdirectory(${IPEX_CPU_ROOT_DIR}/isa) add_subdirectory(${IPEX_CPU_ROOT_DIR}/toolkit) add_subdirectory(${IPEX_CPU_ROOT_DIR}/runtime) add_subdirectory(${IPEX_CPU_ROOT_DIR}/utils) 
+add_subdirectory(${IPEX_CPU_ROOT_DIR}/comm) add_subdirectory(${IPEX_CPU_ROOT_DIR}/jit) @@ -84,7 +103,8 @@ if(USE_LIBXSMM) endif(USE_LIBXSMM) set(IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_DYNDISP_SRCS} ${IPEX_CPU_CPP_ISA_SRCS_GEN} ${IPEX_CPU_CPP_UTILS_SRCS} ${IPEX_CPU_CPP_QUANTIZATION_SRCS} ${IPEX_CPU_CPP_JIT_SRCS} ${IPEX_JIT_COMMON_CPP_SRCS} - ${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS} ${IPEX_CPU_CPP_TPP_SRCS}) + ${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS} + ${IPEX_CPU_CPP_TPP_SRCS} ${IPEX_CPU_CPP_COMM_SRCS}) list(REMOVE_ITEM IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_ISA_SRCS_ORIGIN}) @@ -123,6 +143,7 @@ target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEDNN_GENERATED_INCLUDE} target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/include) target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${PYTHON_INCLUDE_DIR}) +target_link_libraries(${PLUGIN_NAME_CPU} PUBLIC ${DEPENDS_LIB}) include(${IPEX_ROOT_DIR}/cmake/ClangFormat.cmake) if(CLANG_FORMAT) @@ -221,6 +242,7 @@ if(BUILD_STRIPPED_BIN) set_target_properties(${PLUGIN_NAME_CPU} PROPERTIES LINK_FLAGS_RELEASE -s) endif() + install(TARGETS ${PLUGIN_NAME_CPU} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/csrc/cpu/aten/CollectiveCommunicationPrimitive.cpp b/csrc/cpu/aten/CollectiveCommunicationPrimitive.cpp new file mode 100644 index 000000000..0d9cde360 --- /dev/null +++ b/csrc/cpu/aten/CollectiveCommunicationPrimitive.cpp @@ -0,0 +1,37 @@ +#include "CollectiveCommunicationPrimitive.h" +#include +#include +#include + +namespace torch_ipex { +namespace cpu { + +IPEX_DEFINE_DISPATCH(all_reduce_add_kernel_stub); +IPEX_DEFINE_DISPATCH(allgather_kernel_stub); + +at::Tensor all_reduce_add(at::Tensor t_in) { + RECORD_FUNCTION("ipex::all_reduce_add", c10::ArrayRef({})); + return all_reduce_add_kernel_stub(kCPU, t_in); +} + +at::Tensor allgather( + at::Tensor t_in, + std::vector cols_per_rank, + int64_t world_size) { + RECORD_FUNCTION("ipex::allgather", c10::ArrayRef({})); + return allgather_kernel_stub(kCPU, t_in, cols_per_rank, world_size); +} + +} // namespace cpu +} // namespace torch_ipex + +namespace { + +TORCH_LIBRARY_FRAGMENT(torch_ipex, m) { + m.def("all_reduce_add(Tensor(a!) 
t_in)-> (Tensor)"); + m.impl( + "all_reduce_add", c10::DispatchKey::CPU, torch_ipex::cpu::all_reduce_add); + m.def("allgather(Tensor input, int[] output, int world_size) -> (Tensor)"); + m.impl("allgather", c10::DispatchKey::CPU, torch_ipex::cpu::allgather); +} +} // namespace diff --git a/csrc/cpu/aten/CollectiveCommunicationPrimitive.h b/csrc/cpu/aten/CollectiveCommunicationPrimitive.h new file mode 100644 index 000000000..1938f64ee --- /dev/null +++ b/csrc/cpu/aten/CollectiveCommunicationPrimitive.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +namespace torch_ipex { +namespace cpu { + +namespace { + +at::Tensor all_reduce_add(at::Tensor& t_in); +at::Tensor allgather( + at::Tensor t_in, + std::vector cols_per_rank, + int64_t world_size); +int64_t get_world_size(const at::Tensor dummy_input); +int64_t get_rank(const at::Tensor dummy_input); +} // namespace + +using all_reduce_add_fn = at::Tensor (*)(at::Tensor& t_in); +using allgather_fn = at::Tensor (*)( + at::Tensor t_in, + std::vector cols_per_rank, + int64_t world_size); + +IPEX_DECLARE_DISPATCH(all_reduce_add_fn, all_reduce_add_kernel_stub); +IPEX_DECLARE_DISPATCH(allgather_fn, allgather_kernel_stub); + +} // namespace cpu +} // namespace torch_ipex diff --git a/csrc/cpu/aten/ShmAllReduceAdd.cpp b/csrc/cpu/aten/ShmAllReduceAdd.cpp new file mode 100644 index 000000000..55c5d294d --- /dev/null +++ b/csrc/cpu/aten/ShmAllReduceAdd.cpp @@ -0,0 +1,32 @@ + +#include "ShmAllReduceAdd.h" +#include +#include +#include + +namespace torch_ipex { +namespace cpu { + +IPEX_DEFINE_DISPATCH(shm_all_reduce_add_kernel_stub); + +at::Tensor shm_all_reduce_add_forward_cpu( + at::Tensor& t_in, + at::Tensor& t_address, + at::Tensor& t_state, + at::Tensor& t_blockState, + int64_t shm_block_size, + int64_t rank, + int64_t world_size) { + return shm_all_reduce_add_kernel_stub( + kCPU, + t_in, + t_address, + t_state, + t_blockState, + shm_block_size, + rank, + world_size); +} + +} // namespace cpu +} // namespace torch_ipex diff --git a/csrc/cpu/aten/ShmAllReduceAdd.h b/csrc/cpu/aten/ShmAllReduceAdd.h new file mode 100644 index 000000000..d9a2f8cf9 --- /dev/null +++ b/csrc/cpu/aten/ShmAllReduceAdd.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace torch_ipex { +namespace cpu { + +namespace { + +at::Tensor shm_all_reduce_add( + at::Tensor& t_in, + at::Tensor& t_address, + at::Tensor& t_state, + at::Tensor& t_blockState, + int64_t shm_block_size, + int64_t rank, + int64_t world_size); +} + +using shm_all_reduce_add_kernel_fn = at::Tensor (*)( + at::Tensor& t_in, + at::Tensor& t_address, + at::Tensor& t_state, + at::Tensor& t_blockState, + int64_t shm_block_size, + int64_t rank, + int64_t world_size); + +IPEX_DECLARE_DISPATCH( + shm_all_reduce_add_kernel_fn, + shm_all_reduce_add_kernel_stub); + +} // namespace cpu +} // namespace torch_ipex diff --git a/csrc/cpu/aten/kernels/CollectiveCommunicationPrimitiveKrnl.cpp b/csrc/cpu/aten/kernels/CollectiveCommunicationPrimitiveKrnl.cpp new file mode 100644 index 000000000..132d9bc01 --- /dev/null +++ b/csrc/cpu/aten/kernels/CollectiveCommunicationPrimitiveKrnl.cpp @@ -0,0 +1,38 @@ +#include +#include +#include +#include +#include + +namespace torch_ipex { +namespace cpu { + +namespace { +at::Tensor all_reduce_add_kernel_impl(at::Tensor& t_in) { + Messenger::getInstance().reduceAdd(t_in); + return t_in; +} + +at::Tensor allgather_kernel_impl( + at::Tensor t_in, + std::vector cols_per_rank, + int64_t world_size) { + std::vector output_tensors; + auto shape = t_in.contiguous().sizes(); + for 
(int64_t rank = 0; rank < world_size; rank++) { + std::vector t_out_shape(shape.begin(), shape.end() - 1); + t_out_shape.push_back(cols_per_rank[rank + 1] - cols_per_rank[rank]); + output_tensors.push_back(at::empty(t_out_shape, t_in.options())); + } + + return Messenger::getInstance().allgather(t_in, output_tensors); +} + +} // anonymous namespace + +IPEX_REGISTER_DISPATCH(all_reduce_add_kernel_stub, &all_reduce_add_kernel_impl); + +IPEX_REGISTER_DISPATCH(allgather_kernel_stub, &allgather_kernel_impl); + +} // namespace cpu +} // namespace torch_ipex diff --git a/csrc/cpu/aten/kernels/SHMAllreduceAddKrnl.cpp b/csrc/cpu/aten/kernels/SHMAllreduceAddKrnl.cpp new file mode 100644 index 000000000..9e11d0bba --- /dev/null +++ b/csrc/cpu/aten/kernels/SHMAllreduceAddKrnl.cpp @@ -0,0 +1,243 @@ +#include +#include +#include +#include +#include +#include "vec/vec.h" + +namespace torch_ipex { +namespace cpu { + +namespace { + +enum shm_state { INIT = 0, RANK0_COPY = 1, RANKX_COPY_ADD = 2, BROADCAST = 3 }; +enum shm_block_state { INIT_BLOCK = 0, COPY_ADD_DONE_BLOCK = 1 }; + +inline void wait_state_until( + int* states_ptr, + const int index, + enum shm_state state) { + volatile int* state_ptr = states_ptr + index; + while (*state_ptr != state) + _mm_pause(); +} + +inline void wait_block_until( + uint8_t* block_states_ptr, + const int index, + enum shm_block_state state) { + volatile uint8_t* state_ptr = block_states_ptr + index; + while (*state_ptr != state) + _mm_pause(); +} + +template +static inline void multiThreadCopy(DST_T* dst, SRC_T* src, int size) { + RECORD_FUNCTION("multiThreadCopy", c10::ArrayRef({})); + constexpr int sizePerSplit = 512; + int splits = (size + sizePerSplit - 1) / sizePerSplit; +#pragma omp parallel for + for (int i = 0; i < splits; ++i) { + int block_size = + (i == splits - 1) ? (size - i * sizePerSplit) : sizePerSplit; + torch_ipex::cpu::kernel::move_ker( + dst + i * sizePerSplit, src + i * sizePerSplit, block_size); + } +} + +/** + * @brief Performs reduction operation by adding the elements in the send + * buffer in every rank into the shared memory buffer and storing the result + * in the receive buffer. Firstly, the elements in the send buffer are + * copied into the shared memory buffer. Then, the elements in the shared + * memory buffer are added together and stored in the receive buffer. They + * are 4 state to be maintained in the shared memory buffer: 0: ready for + * all-reduce, e.g, initialized or last round all-reduce finished; 1: rank-0 + * copy ready; 2: finish add for other ranks; 3: finish broadcast + * @tparam T The data type of the elements in the buffers. + * @param sendBuf Pointer to the send buffer. + * @param recvBuf Pointer to the receive buffer. + * @param t_address The tensor of the shared memory buffer. + * @param t_state The tensor of the state. + * @param t_blockState The tensor of the block state. + * @param shm_block_size The size of each block in the shared memory buffer. + * @param size The number of elements in the buffers. + * @param element_size The size of each element in bytes. + * @param rank The rank of the current process. + * @param rankSize The total number of processes. 
+ */ +template +void reduceAdd_impl( + T* sendBuf, + T* recvBuf, + at::Tensor t_address, + at::Tensor t_state, + at::Tensor t_blockState, + int shm_block_size, + unsigned long size, + int element_size, + int rank, + int rankSize) { + int nbytes = size * element_size; + int nBlockBytes = shm_block_size * element_size; + int nblocks = (size + shm_block_size - 1) / shm_block_size; + int nthreads = std::min(nblocks, omp_get_max_threads()); + float* address = (float*)t_address.data_ptr(); + uint8_t* block_states_ptr = (uint8_t*)t_blockState.data_ptr(); + int* states_ptr = t_state.data_ptr(); + { + RECORD_FUNCTION( + "ipex::shm_all_reduce_add::rank0_copy", c10::ArrayRef({})); + if (rank == 0) { + for (int i = 1; i < rankSize; i++) { + wait_state_until(states_ptr, i, INIT); + } + multiThreadCopy(address, sendBuf, size); + + } else { + wait_state_until(states_ptr, rank, INIT); + wait_state_until(states_ptr, 0, RANK0_COPY); + } + } + std::atomic_thread_fence(std::memory_order_release); + states_ptr[rank] = RANK0_COPY; + { + RECORD_FUNCTION( + "ipex::shm_all_reduce_add::copy_add_rankx", + c10::ArrayRef({})); + if (rank != 0) { +#pragma omp parallel for num_threads(nthreads) + for (int blockIndex = 0; blockIndex < nblocks; blockIndex++) { + auto lSendBuf = sendBuf + shm_block_size * blockIndex; + auto lAddrBuf = address + shm_block_size * blockIndex; + int realBlockSize = + (blockIndex == (nblocks - 1) + ? (size - shm_block_size * (nblocks - 1)) + : shm_block_size); + + if (rank != 1) { + wait_block_until( + block_states_ptr, + blockIndex * rankSize + rank - 1, + COPY_ADD_DONE_BLOCK); + } + torch_ipex::cpu::kernel::add_ker( + lAddrBuf, lSendBuf, realBlockSize); + std::atomic_thread_fence(std::memory_order_release); + block_states_ptr[blockIndex * rankSize + rank - 1] = INIT_BLOCK; + std::atomic_thread_fence(std::memory_order_release); + block_states_ptr[blockIndex * rankSize + rank] = COPY_ADD_DONE_BLOCK; + } + std::atomic_thread_fence(std::memory_order_release); + states_ptr[rank] = RANKX_COPY_ADD; + } + } + { + RECORD_FUNCTION( + "ipex::shm_all_reduce_add::broadcast", c10::ArrayRef({})); + wait_state_until(states_ptr, rankSize - 1, RANKX_COPY_ADD); + multiThreadCopy(recvBuf, address, size); + if (rank == rankSize - 1) { + for (int i = 0; i < rankSize - 1; i++) { + wait_state_until(states_ptr, i, BROADCAST); + } + + for (int i = 0; i < rankSize; i++) { + std::atomic_thread_fence(std::memory_order_release); + states_ptr[i] = INIT; + } + } else { + std::atomic_thread_fence(std::memory_order_release); + states_ptr[rank] = BROADCAST; + } + } +} + +at::Tensor shm_all_reduce_add_kernel_impl( + at::Tensor& t_in, + at::Tensor& t_address, + at::Tensor& t_state, + at::Tensor& t_blockState, + int64_t shm_block_size, + int64_t rank, + int64_t world_size) { + RECORD_FUNCTION("ipex::shm_all_reduce_add", c10::ArrayRef({})); + // torch_ipex::cpu::shm_all_reduce_add_kernel_stub(kCPU, t_in); + auto dtype = t_in.scalar_type(); + if (dtype == at::ScalarType::BFloat16) { + reduceAdd_impl( + (at::BFloat16*)t_in.data_ptr(), + (at::BFloat16*)t_in.data_ptr(), + t_address, + t_state, + t_blockState, + shm_block_size, + t_in.numel(), + sizeof(at::BFloat16), + rank, + world_size); + } else if (dtype == at::ScalarType::Half) { + reduceAdd_impl( + (at::Half*)t_in.data_ptr(), + (at::Half*)t_in.data_ptr(), + t_address, + t_state, + t_blockState, + shm_block_size, + t_in.numel(), + sizeof(at::Half), + rank, + world_size); + } else if (dtype == at::ScalarType::Float) { + reduceAdd_impl( + (float*)t_in.data_ptr(), + 
(float*)t_in.data_ptr(), + t_address, + t_state, + t_blockState, + shm_block_size, + t_in.numel(), + sizeof(float), + rank, + world_size); + } else if (dtype == at::ScalarType::Int) { + reduceAdd_impl( + (int*)t_in.data_ptr(), + (int*)t_in.data_ptr(), + t_address, + t_state, + t_blockState, + shm_block_size, + t_in.numel(), + sizeof(int), + rank, + world_size); + } else if (dtype == at::ScalarType::Long) { + reduceAdd_impl( + (int64_t*)t_in.data_ptr(), + (int64_t*)t_in.data_ptr(), + t_address, + t_state, + t_blockState, + shm_block_size, + t_in.numel(), + sizeof(int64_t), + rank, + world_size); + } else { + TORCH_CHECK( + false, + "Data Type %s is not supported in SHM based all-reduce!\n", + typeid(dtype).name()); + exit(-1); + } + return t_in; +} +} // namespace + +IPEX_REGISTER_DISPATCH( + shm_all_reduce_add_kernel_stub, + &shm_all_reduce_add_kernel_impl); + +} // namespace cpu +} // namespace torch_ipex diff --git a/csrc/cpu/comm/CMakeLists.txt b/csrc/cpu/comm/CMakeLists.txt new file mode 100644 index 000000000..086fabd9d --- /dev/null +++ b/csrc/cpu/comm/CMakeLists.txt @@ -0,0 +1,5 @@ +FILE(GLOB _COMM_SRCS *.cpp) +LIST(APPEND IPEX_CPU_CPP_COMM_SRCS ${_COMM_SRCS}) + +# Pass to parent +set(IPEX_CPU_CPP_COMM_SRCS ${IPEX_CPU_CPP_COMM_SRCS} PARENT_SCOPE) diff --git a/csrc/cpu/comm/comm.cpp b/csrc/cpu/comm/comm.cpp new file mode 100644 index 000000000..b3c062675 --- /dev/null +++ b/csrc/cpu/comm/comm.cpp @@ -0,0 +1,18 @@ +#include "comm.h" +#include "messager.h" + +namespace torch_ipex { +namespace cpu { +int get_rank() { + return Messenger::getInstance().getRank(); +} + +int get_world_size() { + return Messenger::getInstance().getSize(); +} + +void barrier() { + Messenger::getInstance().barrier(); +} +} // namespace cpu +} // namespace torch_ipex diff --git a/csrc/cpu/comm/comm.h b/csrc/cpu/comm/comm.h new file mode 100644 index 000000000..790b4ef46 --- /dev/null +++ b/csrc/cpu/comm/comm.h @@ -0,0 +1,9 @@ +#pragma once + +namespace torch_ipex { +namespace cpu { +void barrier(); +int get_world_size(); +int get_rank(); +} // namespace cpu +} // namespace torch_ipex \ No newline at end of file diff --git a/csrc/cpu/comm/messager.h b/csrc/cpu/comm/messager.h new file mode 100644 index 000000000..7f9ad5fa0 --- /dev/null +++ b/csrc/cpu/comm/messager.h @@ -0,0 +1,255 @@ +#pragma once +#include + +#include +#include +#include +#include "oneapi/ccl.hpp" +#ifdef USE_SHM +#include "shm_reduction.h" +#endif + +class Messenger { + private: + Messenger() { + // User has set the SINGLE_INSTANCE environment variable + // or program is not with MPI. + if (std::getenv("SINGLE_INSTANCE") != nullptr || !withMpirun()) { + std::cout << "[INFO] SINGLE_INSTANCE MODE." 
<< std::endl; + this->pcomm = nullptr; +#ifdef USE_SHM + this->pshm = nullptr; +#endif + this->rank = 0; + this->size = 1; + return; + } + + int flag = 0; + MPI_Initialized(&flag); + if (flag) { + MPI_Finalize(); + } + ccl::init(); + MPI_Init(NULL, NULL); + MPI_Comm_size(MPI_COMM_WORLD, &size); + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + atexit(Messenger::mpi_finalize); + + if (rank == 0) { + kvs = ccl::create_main_kvs(); + main_addr = kvs->get_address(); + MPI_Bcast( + (void*)main_addr.data(), + main_addr.size(), + MPI_BYTE, + 0, + MPI_COMM_WORLD); + } else { + MPI_Bcast( + (void*)main_addr.data(), + main_addr.size(), + MPI_BYTE, + 0, + MPI_COMM_WORLD); + kvs = ccl::create_kvs(main_addr); + } + + pcomm = new ccl::communicator(ccl::create_communicator(size, rank, kvs)); + + rank = pcomm->rank(); + size = pcomm->size(); + +#ifdef USE_SHM + char my_hostname[MPI_MAX_PROCESSOR_NAME]; + char all_hostnames[MPI_MAX_PROCESSOR_NAME * MPI_MAX_PROCESSOR_NAME]; + int hostname_len; + + // Check ranks are on the same physical machine + MPI_Get_processor_name(my_hostname, &hostname_len); + MPI_Allgather( + my_hostname, + MPI_MAX_PROCESSOR_NAME, + MPI_CHAR, + all_hostnames, + MPI_MAX_PROCESSOR_NAME, + MPI_CHAR, + MPI_COMM_WORLD); + + int same_hostnames = 1; + for (int i = 1; i < size; i++) { + if (strcmp(my_hostname, &all_hostnames[i * MPI_MAX_PROCESSOR_NAME]) != + 0) { + same_hostnames = 0; + break; + } + } + + if (same_hostnames) { + pshm = new ShmReduction(rank, size, [this](int* pid_fd, size_t count) { + this->broadcast(pid_fd, count); + }); + } else { + pshm = nullptr; + } +#endif + } + + ~Messenger() { + delete pcomm; +#ifdef USE_SHM + if (pshm != nullptr) + delete pshm; +#endif + } + + ccl::datatype get_ccl_dtype(at::ScalarType dtype) { + if (dtype == at::ScalarType::BFloat16) { + return ccl::datatype::bfloat16; + } else if (dtype == at::ScalarType::Half) { + return ccl::datatype::float16; + } else if (dtype == at::ScalarType::Float) { + return ccl::datatype::float32; + } else if (dtype == at::ScalarType::Int) { + return ccl::datatype::int32; + } else if (dtype == at::ScalarType::Long) { + return ccl::datatype::int64; + } else { + printf("Type %s not supported!\n", typeid(dtype).name()); + exit(-1); + } + } + + void ccl_allreduce_add(at::Tensor& t_in) { + auto ccl_dtype = get_ccl_dtype(t_in.scalar_type()); + ccl::allreduce( + t_in.data_ptr(), + t_in.data_ptr(), + (size_t)t_in.numel(), + ccl_dtype, + ccl::reduction::sum, + *pcomm) + .wait(); + } + + public: + static Messenger& getInstance() { + static Messenger instance; + return instance; + } + + bool isMaster() { + return rank == 0; + } + + int getRank() { + return rank; + } + + int getSize() { + return size; + } + + /** + * Performs a reduction operation by adding the elements of the input tensor. + * If USE_SHM is defined and the size of the tensor exceeds the shared memory + * size or local ranks flag is false, the reduction is performed using the + * ccl_allreduce_add method. Otherwise, the reduction is performed using the + * reduceAdd method of the pshm object which used SHM. If USE_SHM is not + * defined, the reduction is always performed using the ccl_allreduce_add + * method. + * + * @param t_in The input tensor to be reduced. 
+ */ + void reduceAdd(at::Tensor& t_in) { +#ifdef USE_SHM + if (t_in.numel() * sizeof(float) > pshm->getSHMSize() || pshm == nullptr) { + this->ccl_allreduce_add(t_in); + } else { + pshm->reduceAdd(t_in); + } +#else + this->ccl_allreduce_add(t_in); +#endif + } + + at::Tensor allgather( + at::Tensor data, + const std::vector& vec_data_out) { + std::vector recvCounts; + std::transform( + vec_data_out.begin(), + vec_data_out.end(), + std::back_inserter(recvCounts), + [](const at::Tensor& t) { return t.numel(); }); + std::vector recvBufs; + std::transform( + vec_data_out.begin(), + vec_data_out.end(), + std::back_inserter(recvBufs), + [](const at::Tensor& t) { return t.data_ptr(); }); + { + RECORD_FUNCTION("ccl::allgatherv", std::vector()); + ccl::allgatherv( + data.data_ptr(), + (size_t)data.numel(), + recvBufs, + recvCounts, + get_ccl_dtype(data.scalar_type()), + *pcomm) + .wait(); + } + return at::cat(vec_data_out, -1); + } + + void barrier() { + if (check()) { + ccl::barrier(*pcomm); + } + } + + void broadcast(int* pid_fd, size_t count) { + if (check()) { + ccl::broadcast(pid_fd, count, ccl::datatype::int32, 0, *pcomm).wait(); + } + } + + bool withMpirun() { + return ( + std::getenv("MPI_LOCALRANKID") || std::getenv("MPI_LOCALNRANKS") || + std::getenv("PMI_RANK") || std::getenv("PMI_SIZE") || + std::getenv("PMIX_RANK")); + } + + private: + Messenger(const Messenger& messenger) = delete; + Messenger& operator=(const Messenger& messenger) = delete; + + static void mpi_finalize() { + int is_finalized = 0; + MPI_Finalized(&is_finalized); + + if (!is_finalized) { + MPI_Finalize(); + } + } + + // Check if indeed need to communicate + bool check() { + return size > 1; + } + + private: + int size; + int rank; + + ccl::shared_ptr_class kvs; + ccl::kvs::address_type main_addr; + + ccl::communicator* pcomm; + +#ifdef USE_SHM + ShmReduction* pshm; +#endif +}; diff --git a/csrc/cpu/comm/shm_reduction.h b/csrc/cpu/comm/shm_reduction.h new file mode 100644 index 000000000..eb1717ff3 --- /dev/null +++ b/csrc/cpu/comm/shm_reduction.h @@ -0,0 +1,190 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "aten/ShmAllReduceAdd.h" + +namespace torch_ipex { +namespace cpu { +#define SHM_NAME "ipex_shm_buffer" +#define SHM_BLOCK_SIZE_L (5120) +#define SHM_BLOCK_SIZE_S (16 * 5120) +#define MAX_SHM_BLOCK_COUNT 4096 +#define MAX_SHM_SIZE (SHM_BLOCK_SIZE_S * MAX_SHM_BLOCK_COUNT * sizeof(float)) + +struct ShmContext { + const char* name; + int fp; + int pid_fd[2]; + int* state; + at::Tensor t_state; + uint8_t* blockState; + at::Tensor t_blockState; + void* address; + at::Tensor t_address; + size_t nstates; + size_t nblocks; + size_t nbytes; +}; + +inline void connect_shm(ShmContext* ctx) { + char fd_path[64]; + snprintf( + fd_path, + sizeof(fd_path), + "/proc/%d/fd/%d", + ctx->pid_fd[0], + ctx->pid_fd[1]); + ctx->fp = open(fd_path, O_RDWR); + if (ctx->fp == -1) { + perror("Bad file descriptor."); + exit(-1); + } + + const int total_size = + ctx->nstates * sizeof(int) + ctx->nbytes + ctx->nblocks * ctx->nstates; + + // Map the shared memory into the address space of the process + void* shm_ptr = + mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, ctx->fp, 0); + if (shm_ptr == MAP_FAILED) { + perror("shm mmap failed."); + exit(-1); + } + ctx->state = (int*)shm_ptr; + ctx->t_state = + at::from_blob((void*)ctx->state, {(signed long)ctx->nstates}, at::kInt) + .to(at::kCPU); + 
ctx->blockState = (uint8_t*)((int*)shm_ptr + ctx->nstates); + ctx->t_blockState = + at::from_blob( + (void*)ctx->blockState, + {(signed long)ctx->nblocks, (signed long)ctx->nstates}, + at::kByte) + .to(at::kCPU); + ctx->address = + (void*)((uint8_t*)ctx->blockState + ctx->nblocks * ctx->nstates); + ctx->t_address = at::from_blob( + (void*)ctx->address, + {(signed long)(ctx->nbytes / sizeof(float))}, + at::kFloat) + .to(at::kCPU); +} + +inline void create_shm(ShmContext* ctx) { + ctx->fp = shm_open(ctx->name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); + + if (ctx->fp == -1) { + perror("shm open failed."); + exit(-1); + } + const int total_size = + ctx->nstates * sizeof(int) + ctx->nbytes + ctx->nblocks * ctx->nstates; + // Truncate the shared memory to the desired size + if (ftruncate(ctx->fp, total_size) == -1) { + perror("shm ftruncate failed."); + exit(-1); + } + + // Map the shared memory into the address space of the process + void* shm_ptr = + mmap(NULL, total_size, PROT_READ | PROT_WRITE, MAP_SHARED, ctx->fp, 0); + if (shm_ptr == MAP_FAILED) { + perror("shm mmap failed."); + exit(-1); + } + ctx->pid_fd[0] = getpid(); + ctx->pid_fd[1] = ctx->fp; + ctx->state = (int*)shm_ptr; + ctx->t_state = + at::from_blob((void*)ctx->state, {(signed long)ctx->nstates}, at::kInt) + .to(at::kCPU); + ctx->blockState = (uint8_t*)((int*)shm_ptr + ctx->nstates); + ctx->t_blockState = + at::from_blob( + (void*)ctx->blockState, + {(signed long)ctx->nblocks, (signed long)ctx->nstates}, + at::kByte) + .to(at::kCPU); + ctx->address = + (void*)((uint8_t*)ctx->blockState + ctx->nblocks * ctx->nstates); + ctx->t_address = at::from_blob( + (void*)ctx->address, + {(signed long)(ctx->nbytes / sizeof(float))}, + at::kFloat) + .to(at::kCPU); +} + +inline void close_shm(ShmContext* ctx) { + const int total_size = ctx->nstates * sizeof(int) + ctx->nbytes; + if (ctx->fp != -1) { + munmap(ctx->address, total_size); + shm_unlink(ctx->name); + } +} + +} // namespace cpu +} // namespace torch_ipex + +class ShmReduction { + public: + ShmReduction(int rank, int size, std::function callback) + : rank_(rank), rank_size_(size) { + shmCtx_.name = SHM_NAME; + shmCtx_.nstates = size; + shmCtx_.nbytes = MAX_SHM_SIZE; + shmCtx_.nblocks = MAX_SHM_BLOCK_COUNT; + if (rank_ == 0) { + torch_ipex::cpu::create_shm(&shmCtx_); + memset(shmCtx_.state, 0, shmCtx_.nstates * sizeof(int)); + memset((void*)shmCtx_.blockState, 0, shmCtx_.nstates * shmCtx_.nblocks); + } + + callback(shmCtx_.pid_fd, 2); + + if (rank != 0) { + torch_ipex::cpu::connect_shm(&shmCtx_); + } + } + + ~ShmReduction() { + torch_ipex::cpu::close_shm(&shmCtx_); + } + + int getSHMSize() { + return MAX_SHM_SIZE; + } + + void reduceAdd(at::Tensor& t_in) { + bool is_small = t_in.numel() < 51200; + auto block_size = is_small ? 
SHM_BLOCK_SIZE_S : SHM_BLOCK_SIZE_L; + torch_ipex::cpu::shm_all_reduce_add_kernel_stub( + kCPU, + t_in, + shmCtx_.t_address, + shmCtx_.t_state, + shmCtx_.t_blockState, + block_size, + rank_, + rank_size_); + } + + int rank_; + int rank_size_; + + private: + torch_ipex::cpu::ShmContext shmCtx_; +}; diff --git a/csrc/cpu/jit/passes/graph_rewrite.cpp b/csrc/cpu/jit/passes/graph_rewrite.cpp index fbe685f43..e309877af 100644 --- a/csrc/cpu/jit/passes/graph_rewrite.cpp +++ b/csrc/cpu/jit/passes/graph_rewrite.cpp @@ -1363,11 +1363,39 @@ void simplifyAllReduce(std::shared_ptr& graph) { %r = aten::add_(%r5, %fc_out_bias, %alpha) return (%r) )"; - SubgraphRewriter rewriter_v1, rewriter_v2; + std::string all_reduce_v3 = R"( + graph(%a, %weight, %out_features1, %out_features2, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha, %idx, %no, %dtype, %zero): + %r1 = torch_ipex::tpp_linear(%a, %weight, %out_features1) + %r2 = torch_ipex::inference_all_reduce_add(%r1) + %r2_1 = aten::to(%b, %idx, %no, %no, %dtype) + %r2_2 = aten::contiguous(%r2_1, %zero) + %r3 = torch_ipex::tpp_linear_gelu(%r2_2, %fc_in_weight, %fc_in_bias, %out_features2) + %r4 = aten::to(%r3, %idx, %no, %no, %dtype) + %r5 = aten::contiguous(%r4, %zero) + %r6 = torch_ipex::tpp_linear_bias(%r5, %fc_out_weight, %fc_out_bias, %out_features1) + %r7 = torch_ipex::inference_all_reduce_add(%r6) + %r = aten::add(%r2, %r7, %alpha) + return (%r) )"; + std::string all_reduce_repl_v3 = R"( + graph(%a, %weight, %out_features1, %out_features2, %b, %fc_in_weight, %fc_in_bias, %fc_out_weight, %fc_out_bias, %alpha, %idx, %no, %dtype, %zero): + %r1 = torch_ipex::tpp_linear(%a, %weight, %out_features1) + %r2_1 = aten::to(%b, %idx, %no, %no, %dtype) + %r2_2 = aten::contiguous(%r2_1, %zero) + %r2 = torch_ipex::tpp_linear_gelu(%r2_2, %fc_in_weight, %fc_in_bias, %out_features2) + %r3 = aten::to(%r2, %idx, %no, %no, %dtype) + %r4 = aten::contiguous(%r3, %zero) + %scale: float = prim::Constant[value=1.0]() + %r5 = torch_ipex::tpp_linear_add(%r4, %r1, %fc_out_weight, %fc_out_bias, %scale, %out_features1) + %r6 = torch_ipex::inference_all_reduce_add(%r5) + return (%r6) )"; + + SubgraphRewriter rewriter_v1, rewriter_v2, rewriter_v3; rewriter_v1.RegisterRewritePattern(all_reduce_v1, all_reduce_repl_v1); rewriter_v2.RegisterRewritePattern(all_reduce_v2, all_reduce_repl_v2); + rewriter_v3.RegisterRewritePattern(all_reduce_v3, all_reduce_repl_v3); rewriter_v1.runOnGraph(graph); rewriter_v2.runOnGraph(graph); + rewriter_v3.runOnGraph(graph); } } // namespace graph_rewrite diff --git a/intel_extension_for_pytorch/cpu/__init__.py b/intel_extension_for_pytorch/cpu/__init__.py index 9b2ff9adf..43930bacf 100644 --- a/intel_extension_for_pytorch/cpu/__init__.py +++ b/intel_extension_for_pytorch/cpu/__init__.py @@ -1,3 +1,4 @@ from . import runtime from . import autocast from . import auto_ipex +from . 
import comm diff --git a/intel_extension_for_pytorch/cpu/comm/__init__.py b/intel_extension_for_pytorch/cpu/comm/__init__.py new file mode 100644 index 000000000..fc7d249c2 --- /dev/null +++ b/intel_extension_for_pytorch/cpu/comm/__init__.py @@ -0,0 +1,8 @@ +import torch +import intel_extension_for_pytorch._C as torch_ipex_cpp + +get_world_size = torch_ipex_cpp.get_world_size +get_rank = torch_ipex_cpp.get_rank +barrier = torch_ipex_cpp.barrier +allreduce_add = torch.ops.torch_ipex.all_reduce_add +allgather = torch.ops.torch_ipex.allgather diff --git a/intel_extension_for_pytorch/csrc/cpu/Module.cpp b/intel_extension_for_pytorch/csrc/cpu/Module.cpp index 57bad15b8..3852bb961 100644 --- a/intel_extension_for_pytorch/csrc/cpu/Module.cpp +++ b/intel_extension_for_pytorch/csrc/cpu/Module.cpp @@ -36,6 +36,7 @@ #include "TaskModule.h" #include "aten/EmbeddingBag.h" +#include "comm/comm.h" #include "runtime/CPUPool.h" #include "runtime/TaskExecutor.h" #include "toolkit/sklearn.h" @@ -268,6 +269,11 @@ void InitIpexModuleBindings(py::module m) { m.def("tpp_fused_lamb", &torch_ipex::tpp::fused_lamb); m.def("tpp_fused_lamb_v2", &torch_ipex::tpp::fused_lamb_v2); + // communication related + m.def("get_rank", &torch_ipex::cpu::get_rank); + m.def("get_world_size", &torch_ipex::cpu::get_world_size); + m.def("barrier", &torch_ipex::cpu::barrier); + // Module version m.def("_get_mkl_version", []() { return torch_ipex::utils::get_mkl_version(); diff --git a/intel_extension_for_pytorch/transformers/__init__.py b/intel_extension_for_pytorch/transformers/__init__.py index 38638e517..1e17b4597 100644 --- a/intel_extension_for_pytorch/transformers/__init__.py +++ b/intel_extension_for_pytorch/transformers/__init__.py @@ -2,3 +2,12 @@ from .optimize import _set_optimized_model_for_generation from .models.cpu.modules.attentions import _IPEXAttentionCPU from .models.cpu.modules.decoder import _IPEXDecoderLayerCPU +from .tensor_parallel import ( + shard_lm_head_weights, + shard_mha_weights, + shard_mlp_weights, + update_heads_info, + TensorParallelColumnLinear, + TensorParallelRowLinear, + TensorParallelLMhead, +) diff --git a/intel_extension_for_pytorch/transformers/optimize.py b/intel_extension_for_pytorch/transformers/optimize.py index 83b7a2e3b..503decd3d 100644 --- a/intel_extension_for_pytorch/transformers/optimize.py +++ b/intel_extension_for_pytorch/transformers/optimize.py @@ -13,6 +13,13 @@ _convert_woq_with_low_precision_checkpoint, ) +from .tensor_parallel import ( + shard_lm_head_weights, + shard_mha_weights, + shard_mlp_weights, + update_heads_info, +) + def convert_functions(m, target_m, new_function_name, new_function): for _, sub_m in m.named_children(): @@ -300,6 +307,19 @@ def model_convert_reference(_model): except ImportError: # distributed uses default False pass + need_ipex_tp = False + if _model.device.type == "cpu": + from ..cpu import comm as ipex_comm + + world_size = ipex_comm.get_world_size() + rank = ipex_comm.get_rank() + if world_size > 1: + global distributed + if distributed: + need_ipex_tp = False + else: + need_ipex_tp = True + distributed = True # model-wise optimizations - MHA module for supported_mha_class in [ @@ -312,6 +332,26 @@ def model_convert_reference(_model): transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeAttention, transformers.models.t5.modeling_t5.T5Attention, ]: + if need_ipex_tp and supported_mha_class in [ + transformers.models.llama.modeling_llama.LlamaAttention, + transformers.models.gptj.modeling_gptj.GPTJAttention, + ]: + num_heads = 
_model.config.num_attention_heads + num_kv_heads = num_heads + for name in ["num_key_value_heads"]: + if hasattr(_model.config, name): + num_kv_heads = getattr(_model.config, name) + head_dim = _model.config.hidden_size // num_heads + shard_mha_weights( + _model, + supported_mha_class, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + convert_class( _model, supported_mha_class, @@ -319,6 +359,36 @@ def model_convert_reference(_model): _model.config, distributed=distributed, ) + if need_ipex_tp: + for supported_mlp_class in [ + transformers.models.llama.modeling_llama.LlamaMLP, + transformers.models.gptj.modeling_gptj.GPTJMLP, + ]: + shard_mlp_weights( + _model, + supported_mlp_class, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + for supported_model_class in [ + transformers.models.llama.modeling_llama.LlamaForCausalLM, + transformers.models.gptj.modeling_gptj.GPTJForCausalLM, + ]: + if isinstance(_model, supported_model_class): + shard_lm_head_weights( + _model, + supported_model_class, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + update_heads_info(_model, rank, world_size) + # model-wise optimizations - Feedforward/Decoder layer modules for supported_decoder_class in [ transformers.models.llama.modeling_llama.LlamaDecoderLayer, diff --git a/intel_extension_for_pytorch/transformers/tensor_parallel.py b/intel_extension_for_pytorch/transformers/tensor_parallel.py new file mode 100644 index 000000000..a38181521 --- /dev/null +++ b/intel_extension_for_pytorch/transformers/tensor_parallel.py @@ -0,0 +1,425 @@ +import torch +import torch.nn as nn +from ..cpu import comm as ipex_comm +import os + + +class TensorParallellLinear(nn.Module): + def __init__( + self, + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head, + shard_by_col, + ): + super().__init__() + self.num_kv_heads = num_kv_heads + self.num_heads = num_heads + self.head_dim = head_dim + self.rank = rank + self.world_size = world_size + self.shard_by_head = shard_by_head + self.shard_by_col = shard_by_col + self.cols_per_rank = None + self.shard_weights(linear) + + def shard_weights_by_head( + self, + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_col=True, + ): + if shard_by_col: + total_size = linear.weight.shape[0] + else: + total_size = linear.weight.shape[1] + q_bias = None + k_bias = None + v_bias = None + bias_data = None + concat_qkv = total_size > num_heads * head_dim + kv_group_size = num_heads // num_kv_heads + kv_head_per_rank = num_kv_heads // world_size + if world_size == 1: + return + if world_size > num_kv_heads: + RuntimeError( + f"world_size {world_size} is larger than num_kv_heads {num_kv_heads}" + ) + kv_head_range = [0] # [) + for i in range(world_size - 1, -1, -1): + kv_head_this_rank = kv_head_per_rank + if i < num_kv_heads % world_size: + kv_head_this_rank += 1 + kv_head_range.append(kv_head_range[-1] + kv_head_this_rank) + weight_data = linear.weight.data + q_head_start = kv_head_range[rank] * kv_group_size + q_head_end = ( + q_head_start + + (kv_head_range[rank + 1] - kv_head_range[rank]) * kv_group_size + ) + if shard_by_col: + q = weight_data[q_head_start * head_dim : q_head_end * head_dim] + if linear.bias is not None: + q_bias = linear.bias.data[ + q_head_start * head_dim : q_head_end * head_dim + ] + else: + q = weight_data[:, q_head_start * head_dim : q_head_end * head_dim] + if not concat_qkv: + return torch.nn.Parameter(q), torch.nn.Parameter(q_bias) + + k_head_start = 
num_heads + kv_head_range[rank] + k_head_end = k_head_start + (kv_head_range[rank + 1] - kv_head_range[rank]) + v_head_start = num_heads + num_kv_heads + kv_head_range[rank] + v_head_end = v_head_start + (kv_head_range[rank + 1] - kv_head_range[rank]) + if shard_by_col: + k = weight_data[k_head_start * head_dim : k_head_end * head_dim] + v = weight_data[v_head_start * head_dim : v_head_end * head_dim] + if linear.bias is not None: + k_bias = linear.bias.data[ + k_head_start * head_dim : k_head_end * head_dim + ] + v_bias = linear.bias.data[ + v_head_start * head_dim : v_head_end * head_dim + ] + bias_data = torch.cat([q_bias, k_bias, v_bias], dim=0) + else: + k = weight_data[:, k_head_start * head_dim : k_head_end * head_dim] + v = weight_data[:, v_head_start * head_dim : v_head_end * head_dim] + if linear.bias is not None: + bias_data = linear.bias.data + weight_data = torch.cat([q, k, v], dim=0) + return torch.nn.Parameter(weight_data), torch.nn.Parameter(bias_data) + + def shard_weights_by_block( + self, linear, rank, world_size, shard_by_col=True, block_size=64 + ): + if shard_by_col: + total_size = linear.weight.shape[0] + else: + total_size = linear.weight.shape[1] + bias_data = None + cols_per_rank = [0] + for i in range(world_size - 1, -1, -1): + if total_size % block_size == 0: + block_count = total_size // block_size + block_per_rank = block_count // world_size + if i < block_count % world_size: + block_per_rank += 1 + cols_per_rank.append(cols_per_rank[-1] + block_per_rank * block_size) + else: + cols = total_size // world_size + if i < total_size % world_size: + cols += 1 + cols_per_rank.append(cols_per_rank[-1] + cols) + weight_data = linear.weight.data + if shard_by_col: + weight_data = weight_data[cols_per_rank[rank] : cols_per_rank[rank + 1]] + if linear.bias is not None: + bias_data = linear.bias.data[ + cols_per_rank[rank] : cols_per_rank[rank + 1] + ] + else: + weight_data = weight_data[:, cols_per_rank[rank] : cols_per_rank[rank + 1]] + if linear.bias is not None: + bias_data = linear.bias.data / float(world_size) + return ( + torch.nn.Parameter(weight_data), + torch.nn.Parameter(bias_data), + cols_per_rank, + ) + + def shard_weights(self, linear): + if self.world_size == 1: + return + if self.shard_by_head: + weight, bias = self.shard_weights_by_head( + linear, + self.num_kv_heads, + self.num_heads, + self.head_dim, + self.rank, + self.world_size, + self.shard_by_col, + ) + else: + weight, bias, self.cols_per_rank = self.shard_weights_by_block( + linear, + self.rank, + self.world_size, + self.shard_by_col, + ) + self.linear = nn.Linear( + weight.shape[1], weight.shape[0], bias=linear.bias is not None + ) + + self.linear.weight = weight + if linear.bias is not None: + self.linear.bias = bias + del linear + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return self.linear(input) + + +class TensorParallelColumnLinear(TensorParallellLinear): + def __init__( + self, + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head=True, + ): + super().__init__( + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head, + shard_by_col=True, + ) + + +class TensorParallelRowLinear(TensorParallellLinear): + def __init__( + self, + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head=True, + ): + super().__init__( + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head, + shard_by_col=False, + ) + + def forward(self, input: torch.Tensor) -> 
torch.Tensor: + out = self.linear(input) + if self.world_size > 1: + ipex_comm.allreduce_add(out) + return out + + +class TensorParallelLMhead(TensorParallellLinear): + def __init__( + self, + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_col, + ): + super().__init__( + linear, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head=False, + shard_by_col=shard_by_col, + ) + self.gather_result = shard_by_col + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self.gather_result: + out = self.linear(input) + out = ipex_comm.allgather(out, self.cols_per_rank, self.world_size) + else: + if self.world_size > 1: + input = input[ + ..., + self.cols_per_rank[self.rank] : self.cols_per_rank[self.rank + 1], + ] + out = self.linear(input) + if self.world_size > 1: + ipex_comm.allreduce_add(out) + + return out + + +def shard_mha_weights( + model, target_m, num_heads, num_kv_heads, head_dim, rank, world_size +): + if world_size == 1: + return + for name, sub_m in model.named_children(): + if isinstance(sub_m, target_m): + for l_name, l_sub_m in sub_m.named_children(): + if l_name in ["q_proj"]: + TPLinear = TensorParallelColumnLinear( + l_sub_m, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head=True, + ) + # del sub_m.__dict__["_modules"][l_name] + setattr(sub_m, l_name, TPLinear.linear) + if l_name in ["k_proj", "v_proj"]: + TPLinear = TensorParallelColumnLinear( + l_sub_m, + num_kv_heads, + num_kv_heads, + head_dim, + rank, + world_size, + shard_by_head=True, + ) + # del sub_m.__dict__["_modules"][l_name] + setattr(sub_m, l_name, TPLinear.linear) + if l_name in ["out_proj", "o_proj"]: + TPLinear = TensorParallelRowLinear( + l_sub_m, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head=True, + ) + # del sub_m.__dict__["_modules"][l_name] + setattr(sub_m, l_name, TPLinear) + + shard_mha_weights( + sub_m, target_m, num_heads, num_kv_heads, head_dim, rank, world_size + ) + + +def shard_mlp_weights( + model, target_m, num_heads, num_kv_heads, head_dim, rank, world_size +): + if world_size == 1: + return + for _, sub_m in model.named_children(): + if isinstance(sub_m, target_m): + for l_name, l_sub_m in sub_m.named_children(): + if l_name in ["gate_proj", "up_proj", "fc_in"]: + TPLinear = TensorParallelColumnLinear( + l_sub_m, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_head=False, + ) + setattr(sub_m, l_name, TPLinear.linear) + if l_name in ["down_proj", "fc_out"]: + TPLinear = TensorParallelRowLinear( + l_sub_m, + num_kv_heads, + num_kv_heads, + head_dim, + rank, + world_size, + shard_by_head=False, + ) + setattr(sub_m, l_name, TPLinear) + shard_mlp_weights( + sub_m, target_m, num_heads, num_kv_heads, head_dim, rank, world_size + ) + + +def shard_lm_head_weights( + model, supported_model_class, num_heads, num_kv_heads, head_dim, rank, world_size +): + if world_size == 1: + return + if not isinstance(model, supported_model_class): + return + for name, sub_m in model.named_children(): + lm_head_shard_policy = os.getenv("LM_HEAD_SHARD_POLICY", "row") + shard_by_col = lm_head_shard_policy == "col" + if name in ["lm_head"]: + TPLinear = TensorParallelLMhead( + sub_m, + num_kv_heads, + num_heads, + head_dim, + rank, + world_size, + shard_by_col=shard_by_col, + ) + setattr(model, name, TPLinear) + return + shard_lm_head_weights( + sub_m, + supported_model_class, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + + +def 
update_heads_info(_model, rank, world_size): + # update the head number of config after sharding + num_heads = _model.config.num_attention_heads + num_kv_heads = num_heads + head_dim = _model.config.hidden_size // num_heads + for name in ["num_key_value_heads"]: + if hasattr(_model.config, name): + num_kv_heads = getattr(_model.config, name) + group_size = num_heads // num_kv_heads + kv_head_per_rank = num_kv_heads // world_size + assert world_size <= num_kv_heads + kv_heads_range = [0] + for i in range(world_size - 1, -1, -1): + kv_heads_this_rank = kv_head_per_rank + if i < num_kv_heads % world_size: + kv_heads_this_rank += 1 + kv_heads_range.append(kv_heads_range[-1] + kv_heads_this_rank) + + def update(_model, group_size, kv_head_range): + for _, sub_m in _model.named_children(): + # update number of query heads + target_kv_head = kv_heads_range[rank + 1] - kv_heads_range[rank] + for name in ["num_attention_heads", "num_heads"]: + if hasattr(sub_m, "config") and hasattr(sub_m.config, name): + setattr(sub_m.config, name, group_size * target_kv_head) + if hasattr(sub_m, name): + setattr(sub_m, name, group_size * target_kv_head) + # update number of key/value heads + for name in ["num_key_value_heads", "num_key_value_heads"]: + if hasattr(sub_m, "config") and hasattr(sub_m.config, name): + setattr(sub_m.config, name, target_kv_head) + if hasattr(sub_m, name): + setattr(sub_m, name, target_kv_head) + # update hidden_size + for name in ["hidden_size"]: + if hasattr(sub_m, "config") and hasattr(sub_m.config, name): + setattr(sub_m.config, name, group_size * target_kv_head * head_dim) + if hasattr(sub_m, name): + setattr(sub_m, name, group_size * target_kv_head * head_dim) + update(sub_m, group_size, kv_head_range) + + update(_model, group_size, kv_heads_range) diff --git a/tests/cpu/test_ccl_primitive.py b/tests/cpu/test_ccl_primitive.py new file mode 100644 index 000000000..432e1521d --- /dev/null +++ b/tests/cpu/test_ccl_primitive.py @@ -0,0 +1,55 @@ +import unittest +import os +import torch +import intel_extension_for_pytorch as ipex + + +@unittest.skip("oneccl can't works in docker") +class CCLTester(unittest.TestCase): + def test_all_reduce_add(self): + mpi_world_size = int(os.environ.get("PMI_SIZE", -1)) + mpi_rank = int(os.environ.get("PMI_RANK", -1)) + ipex.enable_onednn_fusion(False) # just to workaround the flake8 + dtypes = [torch.float32, torch.float16, torch.bfloat16] + tensor_sizes = [4096, 4096 * 32, 8 * 1024 * 5120 * 4 * 2] + # Less than 8 * 1024 * 5120 * 4 use SHM, otherwise use ccl allreduce + # The above dispatch rule is transparent to users + for dtype in dtypes: + for tensor_size in tensor_sizes: + input_tensor = ( + torch.tensor([mpi_rank + 1.0]).to(dtype).repeat(tensor_size) + ) + target_tensor = ( + torch.tensor([float(mpi_world_size * (mpi_world_size + 1) / 2)]) + .to(dtype) + .repeat(tensor_size) + ) + ipex.cpu.comm.allreduce_add(input_tensor) + torch.allclose(input_tensor, target_tensor) + ipex.cpu.comm.barrier() + + self.assertEqual(mpi_world_size, ipex.cpu.comm.get_world_size()) + self.assertEqual(mpi_rank, ipex.cpu.comm.get_rank()) + + def test_allgather(self): + mpi_world_size = int(os.environ.get("PMI_SIZE", -1)) + mpi_rank = int(os.environ.get("PMI_RANK", -1)) + dtypes = [torch.float32, torch.float16, torch.bfloat16] + for dtype in dtypes: + for n in range(mpi_world_size + 1): + n = n + 1 + input = (torch.tensor([n * mpi_rank])).to(dtype) + col_per_rank = [] + for i in range(mpi_world_size + 1): + col_per_rank.append(i) + expected_output = [ + torch.tensor([i 
* n]).to(dtype) for i in range(mpi_world_size) + ] + expected_output = torch.cat(expected_output, dim=0) + + output = ipex.cpu.comm.allgather(input, col_per_rank, mpi_world_size) + torch.allclose(expected_output, output) + + +if __name__ == "__main__": + test = unittest.main() diff --git a/tests/cpu/test_ipex_tensor_parallel.py b/tests/cpu/test_ipex_tensor_parallel.py new file mode 100644 index 000000000..6b29d5265 --- /dev/null +++ b/tests/cpu/test_ipex_tensor_parallel.py @@ -0,0 +1,146 @@ +import unittest +import torch +import intel_extension_for_pytorch as ipex +import sys +import subprocess +import os +import copy +from intel_extension_for_pytorch.transformers import ( + shard_mha_weights, + shard_mlp_weights, + shard_lm_head_weights, + update_heads_info, + TensorParallelRowLinear, + TensorParallelLMhead, +) +from intel_extension_for_pytorch.cpu import comm as ipex_comm + +try: + import transformers + from transformers import AutoConfig +except ImportError: + subprocess.check_call( + [sys.executable, "-m", "pip", "install", "transformers==4.35.2"] + ) + import transformers + from transformers import AutoConfig + +from common_utils import TestCase + +torch.manual_seed(128) + +curpath = os.path.abspath(os.path.dirname(__file__)) + + +@unittest.skip("oneccl can't works in docker") +class TensorParallelTester(TestCase): + def _shard_model(self, model): + rank = ipex_comm.get_rank() + world_size = ipex_comm.get_world_size() + for supported_mha_class in [ + transformers.models.llama.modeling_llama.LlamaAttention, + transformers.models.gptj.modeling_gptj.GPTJAttention, + ]: + num_heads = model.config.num_attention_heads + num_kv_heads = num_heads + for name in ["num_key_value_heads"]: + if hasattr(model.config, name): + num_kv_heads = getattr(model.config, name) + head_dim = model.config.hidden_size // num_heads + shard_mha_weights( + model, + supported_mha_class, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + for supported_mlp_class in [ + transformers.models.llama.modeling_llama.LlamaMLP, + transformers.models.gptj.modeling_gptj.GPTJMLP, + ]: + shard_mlp_weights( + model, + supported_mlp_class, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + for supported_model_calss in [ + transformers.models.llama.modeling_llama.LlamaForCausalLM, + transformers.models.gptj.modeling_gptj.GPTJForCausalLM, + ]: + if isinstance(model, supported_model_calss): + shard_lm_head_weights( + model, + supported_model_calss, + num_heads, + num_kv_heads, + head_dim, + rank, + world_size, + ) + update_heads_info(model, rank, world_size) + return model + + def tensor_parallel_with_optimize_transformers(self, model): + ref_m = copy.deepcopy(model) + ipex_model = ipex.optimize_transformers(model) + input_ids = torch.ones(10).to(torch.long) + attention_mask = torch.ones(len(input_ids)) + position_ids = torch.arange(len(input_ids)) + input_dict = { + "input_ids": input_ids.unsqueeze(0), + "attention_mask": attention_mask.unsqueeze(0), + "use_cache": True, + } + input_dict["position_ids"] = position_ids.unsqueeze(0) + + for dtype in [torch.float32, torch.bfloat16]: + with torch.no_grad(), torch.cpu.amp.autocast( + enabled=True if dtype is torch.bfloat16 else False + ): + key_hf = ref_m(**input_dict) + key_ipex = ipex_model(**input_dict) + + self.assertEqual(key_hf[0], key_ipex[0], prec=0.1) + + def test_tensor_parallel_replace_check_gptj(self): + config = AutoConfig.from_pretrained( + f"{curpath}/hf_configs/gptj", return_dict=False + ) + model = 
transformers.models.gptj.modeling_gptj.GPTJForCausalLM(config).eval() + tp_model = self._shard_model(copy.deepcopy(model)) + self.assertTrue( + isinstance(tp_model.transformer.h[0].attn.out_proj, TensorParallelRowLinear) + ) + self.assertTrue( + isinstance(tp_model.transformer.h[0].mlp.fc_out, TensorParallelRowLinear) + ) + self.assertTrue(isinstance(tp_model.lm_head, TensorParallelLMhead)) + self.tensor_parallel_with_optimize_transformers(model) + + def test_tensor_parallel_replace_check_llama(self): + config = AutoConfig.from_pretrained( + f"{curpath}/hf_configs/llama", return_dict=False + ) + model = transformers.models.llama.modeling_llama.LlamaForCausalLM(config).eval() + tp_model = self._shard_model(copy.deepcopy(model)) + self.assertTrue( + isinstance( + tp_model.model.layers[0].self_attn.o_proj, TensorParallelRowLinear + ) + ) + self.assertTrue( + isinstance(tp_model.model.layers[0].mlp.down_proj, TensorParallelRowLinear) + ) + self.assertTrue(isinstance(tp_model.lm_head, TensorParallelLMhead)) + self.assertTrue(tp_model.lm_head, TensorParallelLMhead) + self.tensor_parallel_with_optimize_transformers(model) + + +if __name__ == "__main__": + test = unittest.main() diff --git a/third_party/oneCCL b/third_party/oneCCL new file mode 160000 index 000000000..796b7ef54 --- /dev/null +++ b/third_party/oneCCL @@ -0,0 +1 @@ +Subproject commit 796b7ef54cd77bd181fc0d8881385b81222a7a36
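
A minimal usage sketch of the communication primitives added by this patch, modeled on tests/cpu/test_ccl_primitive.py and the bindings in intel_extension_for_pytorch/cpu/comm/__init__.py. It assumes the extension is built with USE_CCL enabled (the default in cmake/cpu/Options.cmake) and is launched under an MPI launcher that sets PMI_RANK/PMI_SIZE (e.g. `mpirun -n 2 python your_script.py`, where the script name is hypothetical); without such a launcher, Messenger falls back to single-instance mode with rank 0 and world size 1.

    import torch
    import intel_extension_for_pytorch as ipex

    rank = ipex.cpu.comm.get_rank()
    world_size = ipex.cpu.comm.get_world_size()

    # In-place sum across ranks; afterwards every rank holds the same result.
    # When USE_SHM is on and all ranks share a host, sufficiently small tensors
    # take the shared-memory path; larger ones fall back to oneCCL allreduce.
    t = torch.full((4096,), float(rank + 1))
    ipex.cpu.comm.allreduce_add(t)

    # Gather per-rank shards along the last dimension. cols_per_rank holds the
    # prefix offsets of each rank's shard width (length world_size + 1), which
    # is what the allgather kernel uses to size each output shard.
    shard = torch.full((1,), float(rank))
    cols_per_rank = list(range(world_size + 1))  # each rank contributes one column
    gathered = ipex.cpu.comm.allgather(shard, cols_per_rank, world_size)

    ipex.cpu.comm.barrier()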