IPEX Tensor Parallel (#2435)
liangan1 authored Mar 7, 2024
1 parent f4ee125 commit 4fa6445
Showing 26 changed files with 1,717 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .gitmodules
@@ -7,6 +7,10 @@
[submodule "third_party/libxsmm"]
path = third_party/libxsmm
url = https://github.com/libxsmm/libxsmm.git
[submodule "third_party/oneCCL"]
path = third_party/oneCCL
url = https://github.com/oneapi-src/oneCCL
[submodule "third_party/sleef"]
path = third_party/sleef
url = https://github.com/shibatch/sleef.git

36 changes: 36 additions & 0 deletions cmake/Modules/FindoneCCL.cmake
@@ -0,0 +1,36 @@
# - Try to find oneCCL
#
# The following are set after configuration is done:
# ONECCL_FOUND : set to true if oneCCL is found.
# ONECCL_INCLUDE_DIRS : path to oneCCL include dir.
# ONECCL_LIBRARIES : list of libraries for oneCCL
#
# and the following imported targets:
#
# oneCCL

IF (NOT ONECCL_FOUND)
SET(ONECCL_FOUND OFF)
SET(ONECCL_LIBRARIES)
SET(ONECCL_INCLUDE_DIRS)

SET(ONECCL_ROOT "${PROJECT_SOURCE_DIR}/third_party/oneCCL")

IF(BUILD_NO_ONECCL_PACKAGE)
ADD_SUBDIRECTORY(${ONECCL_ROOT} oneCCL EXCLUDE_FROM_ALL)
ELSE()
ADD_SUBDIRECTORY(${ONECCL_ROOT} build)
ENDIF()

IF(NOT TARGET ccl)
MESSAGE(FATAL_ERROR "Failed to find oneCCL target")
ENDIF()
add_library(oneCCL ALIAS ccl)

GET_TARGET_PROPERTY(INCLUDE_DIRS oneCCL INCLUDE_DIRECTORIES)
SET(ONECCL_INCLUDE_DIRS ${INCLUDE_DIRS})
SET(ONECCL_LIBRARIES oneCCL)

find_package_handle_standard_args(oneCCL FOUND_VAR ONECCL_FOUND REQUIRED_VARS ONECCL_LIBRARIES ONECCL_INCLUDE_DIRS)

ENDIF(NOT ONECCL_FOUND)
3 changes: 2 additions & 1 deletion cmake/cpu/BuildFlags.cmake
@@ -13,7 +13,8 @@ if(env_cxx_standard GREATER -1)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_EXTENSIONS OFF)
# The oneCCL build only supports the GNU standard
set(CMAKE_CXX_EXTENSIONS ON)

if(MSVC)
set(CMAKE_COMPILE_WARNING_AS_ERROR OFF)
10 changes: 10 additions & 0 deletions cmake/cpu/Options.cmake
@@ -13,6 +13,14 @@ if(WIN32)
set(USE_LIBXSMM ON)
endif()

option(USE_CCL "Enable oneCCL in IPEX" ON)
option(USE_SHM "Enable shared memory communication in IPEX" ON)
if(WIN32)
set(USE_SHM OFF)
endif()
#set USE_SHM to OFF if USE_CCL is OFF


cmake_dependent_option(BUILD_STATIC_ONEMKL "Static link with oneMKL" OFF "BUILD_WITH_XPU" ON)

function (print_cpu_config_summary)
@@ -49,6 +57,8 @@ function (print_cpu_config_summary)
message(STATUS " IPEX_DISP_OP : ${IPEX_DISP_OP}")
message(STATUS " BUILD_XSMM_VIA_CMAKE : ${BUILD_LIBXSMM_VIA_CMAKE}")
message(STATUS " USE_LIBXSMM : ${USE_LIBXSMM}")
message(STATUS " USE_CCL : ${USE_CCL}")
message(STATUS " USE_SHM : ${USE_SHM}")
message(STATUS "")
message(STATUS "********************************")
endfunction()
24 changes: 23 additions & 1 deletion csrc/cpu/CMakeLists.txt
@@ -11,6 +11,16 @@ set(DNNL_ENABLE_PRIMITIVE_CACHE TRUE CACHE BOOL "" FORCE)
set(DNNL_LIBRARY_TYPE STATIC CACHE STRING "" FORCE)

#find_package(TorchCCL REQUIRED)
# Find OneCCL Lib
set(DEPENDS_LIB)
if(USE_CCL)
include(${IPEX_ROOT_DIR}/cmake/Modules/FindoneCCL.cmake)
link_directories(${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/oneCCL/deps/mpi/lib)
find_package(oneCCL REQUIRED)
list(APPEND DEPENDS_LIB oneCCL)
list(APPEND DEPENDS_LIB mpi)
endif()

# TODO: Once llga is merged into oneDNN, use oneDNN directly as the third_party of IPEX
# use the oneDNN in llga temporarily: third_party/llga/third_party/oneDNN
@@ -34,6 +44,14 @@ if(USE_LIBXSMM)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_LIBXSMM")
endif(USE_LIBXSMM)

if(USE_CCL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_CCL")
endif(USE_CCL)

if(USE_SHM)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_SHM")
endif(USE_SHM)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_IPEX_MAIN_LIB")

# ---[ Main build
@@ -73,6 +91,7 @@ add_subdirectory(${IPEX_CPU_ROOT_DIR}/isa)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/toolkit)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/runtime)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/utils)
add_subdirectory(${IPEX_CPU_ROOT_DIR}/comm)

add_subdirectory(${IPEX_CPU_ROOT_DIR}/jit)

@@ -84,7 +103,8 @@ if(USE_LIBXSMM)
endif(USE_LIBXSMM)

set(IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_DYNDISP_SRCS} ${IPEX_CPU_CPP_ISA_SRCS_GEN} ${IPEX_CPU_CPP_UTILS_SRCS} ${IPEX_CPU_CPP_QUANTIZATION_SRCS} ${IPEX_CPU_CPP_JIT_SRCS} ${IPEX_JIT_COMMON_CPP_SRCS}
${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS} ${IPEX_CPU_CPP_TPP_SRCS})
${IPEX_CPU_CPP_ISA_SRCS} ${IPEX_CPU_CPP_IDEEP_SRCS} ${IPEX_CPU_CPP_AUTOCAST_SRCS} ${IPEX_CPU_CPP_ATEN_SRCS} ${IPEX_CPU_CPP_RUNTIME_SRCS} ${IPEX_CPU_CPP_TOOLKIT_SRCS} ${IPEX_UTLIS_CPP_SRCS}
${IPEX_CPU_CPP_TPP_SRCS} ${IPEX_CPU_CPP_COMM_SRCS})

list(REMOVE_ITEM IPEX_CPU_CPP_SRCS ${IPEX_CPU_CPP_ISA_SRCS_ORIGIN})

@@ -123,6 +143,7 @@ target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${ONEDNN_GENERATED_INCLUDE}

target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${IPEX_CPU_CPP_THIRD_PARTY_ROOT}/ideep/include)
target_include_directories(${PLUGIN_NAME_CPU} PUBLIC ${PYTHON_INCLUDE_DIR})
target_link_libraries(${PLUGIN_NAME_CPU} PUBLIC ${DEPENDS_LIB})

include(${IPEX_ROOT_DIR}/cmake/ClangFormat.cmake)
if(CLANG_FORMAT)
@@ -221,6 +242,7 @@ if(BUILD_STRIPPED_BIN)
set_target_properties(${PLUGIN_NAME_CPU} PROPERTIES LINK_FLAGS_RELEASE -s)
endif()


install(TARGETS ${PLUGIN_NAME_CPU}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
37 changes: 37 additions & 0 deletions csrc/cpu/aten/CollectiveCommunicationPrimitive.cpp
@@ -0,0 +1,37 @@
#include "CollectiveCommunicationPrimitive.h"
#include <ATen/FunctionalTensorWrapper.h>
#include <torch/all.h>
#include <torch/csrc/autograd/function.h>

namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(all_reduce_add_kernel_stub);
IPEX_DEFINE_DISPATCH(allgather_kernel_stub);

at::Tensor all_reduce_add(at::Tensor t_in) {
RECORD_FUNCTION("ipex::all_reduce_add", c10::ArrayRef<c10::IValue>({}));
return all_reduce_add_kernel_stub(kCPU, t_in);
}

at::Tensor allgather(
at::Tensor t_in,
std::vector<int64_t> cols_per_rank,
int64_t world_size) {
RECORD_FUNCTION("ipex::allgather", c10::ArrayRef<c10::IValue>({}));
return allgather_kernel_stub(kCPU, t_in, cols_per_rank, world_size);
}

} // namespace cpu
} // namespace torch_ipex

namespace {

TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def("all_reduce_add(Tensor(a!) t_in)-> (Tensor)");
m.impl(
"all_reduce_add", c10::DispatchKey::CPU, torch_ipex::cpu::all_reduce_add);
m.def("allgather(Tensor input, int[] output, int world_size) -> (Tensor)");
m.impl("allgather", c10::DispatchKey::CPU, torch_ipex::cpu::allgather);
}
} // namespace
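
The fragment above registers both primitives under the torch_ipex namespace, so they are callable from Python once the extension is loaded. A minimal usage sketch, assuming a build with USE_CCL=ON and one process launched per rank (the launch setup and tensor shapes here are illustrative, not part of this commit):

import torch
import intel_extension_for_pytorch  # loading the extension registers the torch_ipex ops

# Every rank passes a tensor of the same shape; all_reduce_add sums the
# tensors across ranks in place and returns the reduced result.
t = torch.ones(2, 4)
reduced = torch.ops.torch_ipex.all_reduce_add(t)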
30 changes: 30 additions & 0 deletions csrc/cpu/aten/CollectiveCommunicationPrimitive.h
@@ -0,0 +1,30 @@
#pragma once

#include <ATen/ATen.h>
#include <dyndisp/DispatchStub.h>

namespace torch_ipex {
namespace cpu {

namespace {

at::Tensor all_reduce_add(at::Tensor& t_in);
at::Tensor allgather(
at::Tensor t_in,
std::vector<int64_t> cols_per_rank,
int64_t world_size);
int64_t get_world_size(const at::Tensor dummy_input);
int64_t get_rank(const at::Tensor dummy_input);
} // namespace

using all_reduce_add_fn = at::Tensor (*)(at::Tensor& t_in);
using allgather_fn = at::Tensor (*)(
at::Tensor t_in,
std::vector<int64_t> cols_per_rank,
int64_t world_size);

IPEX_DECLARE_DISPATCH(all_reduce_add_fn, all_reduce_add_kernel_stub);
IPEX_DECLARE_DISPATCH(allgather_fn, allgather_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
32 changes: 32 additions & 0 deletions csrc/cpu/aten/ShmAllReduceAdd.cpp
@@ -0,0 +1,32 @@

#include "ShmAllReduceAdd.h"
#include <ATen/FunctionalTensorWrapper.h>
#include <torch/all.h>
#include <torch/csrc/autograd/function.h>

namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(shm_all_reduce_add_kernel_stub);

at::Tensor shm_all_reduce_add_forward_cpu(
at::Tensor& t_in,
at::Tensor& t_address,
at::Tensor& t_state,
at::Tensor& t_blockState,
int64_t shm_block_size,
int64_t rank,
int64_t world_size) {
return shm_all_reduce_add_kernel_stub(
kCPU,
t_in,
t_address,
t_state,
t_blockState,
shm_block_size,
rank,
world_size);
}

} // namespace cpu
} // namespace torch_ipex
35 changes: 35 additions & 0 deletions csrc/cpu/aten/ShmAllReduceAdd.h
@@ -0,0 +1,35 @@
#pragma once

#include <ATen/ATen.h>
#include <dyndisp/DispatchStub.h>

namespace torch_ipex {
namespace cpu {

namespace {

at::Tensor shm_all_reduce_add(
at::Tensor& t_in,
at::Tensor& t_address,
at::Tensor& t_state,
at::Tensor& t_blockState,
int64_t shm_block_size,
int64_t rank,
int64_t world_size);
}

using shm_all_reduce_add_kernel_fn = at::Tensor (*)(
at::Tensor& t_in,
at::Tensor& t_address,
at::Tensor& t_state,
at::Tensor& t_blockState,
int64_t shm_block_size,
int64_t rank,
int64_t world_size);

IPEX_DECLARE_DISPATCH(
shm_all_reduce_add_kernel_fn,
shm_all_reduce_add_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
38 changes: 38 additions & 0 deletions csrc/cpu/aten/kernels/CollectiveCommunicationPrimitiveKrnl.cpp
@@ -0,0 +1,38 @@
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <aten/CollectiveCommunicationPrimitive.h>
#include <comm/messager.h>
#include <torch/csrc/autograd/function.h>

namespace torch_ipex {
namespace cpu {

namespace {
at::Tensor all_reduce_add_kernel_impl(at::Tensor& t_in) {
Messenger::getInstance().reduceAdd(t_in);
return t_in;
}

at::Tensor allgather_kernel_impl(
at::Tensor t_in,
std::vector<int64_t> cols_per_rank,
int64_t world_size) {
std::vector<at::Tensor> output_tensors;
auto shape = t_in.contiguous().sizes();
for (int64_t rank = 0; rank < world_size; rank++) {
std::vector<int64_t> t_out_shape(shape.begin(), shape.end() - 1);
t_out_shape.push_back(cols_per_rank[rank + 1] - cols_per_rank[rank]);
output_tensors.push_back(at::empty(t_out_shape, t_in.options()));
}

return Messenger::getInstance().allgather(t_in, output_tensors);
}

} // anonymous namespace

IPEX_REGISTER_DISPATCH(all_reduce_add_kernel_stub, &all_reduce_add_kernel_impl);

IPEX_REGISTER_DISPATCH(allgather_kernel_stub, &allgather_kernel_impl);

} // namespace cpu
} // namespace torch_ipex
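
The kernel sizes each rank's output slice from consecutive entries of cols_per_rank, so that argument is a prefix sum over the per-rank column counts with world_size + 1 entries. A sketch of the expected calling convention, with illustrative values (the gathered result is presumably the per-rank shards joined along the last dimension):

import torch
import intel_extension_for_pytorch  # assumption: loading registers torch_ipex.allgather

world_size = 2
rank = 0  # illustrative; normally obtained from the launcher
# Prefix sum over column counts: rank 0 owns 5 columns, rank 1 owns 3.
cols_per_rank = [0, 5, 8]
local_cols = cols_per_rank[rank + 1] - cols_per_rank[rank]
local = torch.randn(4, local_cols)
# Expected to return a [4, 8] tensor combining the shards from all ranks.
full = torch.ops.torch_ipex.allgather(local, cols_per_rank, world_size)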
