-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
26 changed files
with
1,717 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# - Try to find oneCCL
#
# The following are set after configuration is done:
#  ONECCL_FOUND        : set to true if oneCCL is found.
#  ONECCL_INCLUDE_DIRS : path to oneCCL include dir.
#  ONECCL_LIBRARIES    : list of libraries for oneCCL
#
# and the following imported targets:
#
#  oneCCL

if(NOT ONECCL_FOUND)
  set(ONECCL_FOUND OFF)
  set(ONECCL_LIBRARIES)
  set(ONECCL_INCLUDE_DIRS)

  # oneCCL is vendored under third_party/ and built as part of this project.
  set(ONECCL_ROOT "${PROJECT_SOURCE_DIR}/third_party/oneCCL")

  # With packaging disabled, EXCLUDE_FROM_ALL keeps the submodule's own
  # targets out of the default build.
  if(BUILD_NO_ONECCL_PACKAGE)
    add_subdirectory(${ONECCL_ROOT} oneCCL EXCLUDE_FROM_ALL)
  else()
    add_subdirectory(${ONECCL_ROOT} build)
  endif()

  # The submodule is expected to define the `ccl` library target.
  if(NOT TARGET ccl)
    message(FATAL_ERROR "Failed to find oneCCL target")
  endif()
  add_library(oneCCL ALIAS ccl)

  # Mirror the target's include directories into the conventional variables.
  get_target_property(INCLUDE_DIRS oneCCL INCLUDE_DIRECTORIES)
  set(ONECCL_INCLUDE_DIRS ${INCLUDE_DIRS})
  set(ONECCL_LIBRARIES oneCCL)

  find_package_handle_standard_args(oneCCL FOUND_VAR ONECCL_FOUND REQUIRED_VARS ONECCL_LIBRARIES ONECCL_INCLUDE_DIRS)

endif(NOT ONECCL_FOUND)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
#include "CollectiveCommunicationPrimitive.h" | ||
#include <ATen/FunctionalTensorWrapper.h> | ||
#include <torch/all.h> | ||
#include <torch/csrc/autograd/function.h> | ||
|
||
namespace torch_ipex { | ||
namespace cpu { | ||
|
||
IPEX_DEFINE_DISPATCH(all_reduce_add_kernel_stub); | ||
IPEX_DEFINE_DISPATCH(allgather_kernel_stub); | ||
|
||
at::Tensor all_reduce_add(at::Tensor t_in) { | ||
RECORD_FUNCTION("ipex::all_reduce_add", c10::ArrayRef<c10::IValue>({})); | ||
return all_reduce_add_kernel_stub(kCPU, t_in); | ||
} | ||
|
||
at::Tensor allgather( | ||
at::Tensor t_in, | ||
std::vector<int64_t> cols_per_rank, | ||
int64_t world_size) { | ||
RECORD_FUNCTION("ipex::allgather", c10::ArrayRef<c10::IValue>({})); | ||
return allgather_kernel_stub(kCPU, t_in, cols_per_rank, world_size); | ||
} | ||
|
||
} // namespace cpu | ||
} // namespace torch_ipex | ||
|
||
namespace {

// Registers the collective-communication ops with the torch_ipex library and
// binds their CPU implementations.
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
  // Tensor(a!) declares that the input tensor is mutated in place.
  m.def("all_reduce_add(Tensor(a!) t_in)-> (Tensor)");
  m.impl(
      "all_reduce_add", c10::DispatchKey::CPU, torch_ipex::cpu::all_reduce_add);
  // NOTE(review): the schema names the int[] argument "output", but the C++
  // implementation consumes it as per-rank column boundaries (cols_per_rank);
  // confirm the schema argument name against Python-side callers before
  // renaming.
  m.def("allgather(Tensor input, int[] output, int world_size) -> (Tensor)");
  m.impl("allgather", c10::DispatchKey::CPU, torch_ipex::cpu::allgather);
}
} // namespace
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#pragma once

#include <ATen/ATen.h>
#include <dyndisp/DispatchStub.h>

#include <vector>

namespace torch_ipex {
namespace cpu {

// Op entry points, defined in CollectiveCommunicationPrimitive.cpp.
// They are declared at namespace scope: an anonymous namespace in a header
// would give every including translation unit its own internal declarations
// that never match the external definitions.

// Sum all-reduce of `t_in` across ranks. Declared by value to match the
// definition (the previous by-reference declaration matched no definition).
at::Tensor all_reduce_add(at::Tensor t_in);

// Gather each rank's shard of `t_in`. `cols_per_rank` holds per-rank column
// boundaries; `world_size` is the number of participating ranks.
at::Tensor allgather(
    at::Tensor t_in,
    std::vector<int64_t> cols_per_rank,
    int64_t world_size);

// Communication-backend queries.
// NOTE(review): `dummy_input` looks unused for the query itself — presumably
// present to drive dispatch; confirm against the implementations.
int64_t get_world_size(const at::Tensor dummy_input);
int64_t get_rank(const at::Tensor dummy_input);

// Kernel function-pointer signatures for the dispatch stubs below.
using all_reduce_add_fn = at::Tensor (*)(at::Tensor& t_in);
using allgather_fn = at::Tensor (*)(
    at::Tensor t_in,
    std::vector<int64_t> cols_per_rank,
    int64_t world_size);

IPEX_DECLARE_DISPATCH(all_reduce_add_fn, all_reduce_add_kernel_stub);
IPEX_DECLARE_DISPATCH(allgather_fn, allgather_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
|
||
#include "ShmAllReduceAdd.h" | ||
#include <ATen/FunctionalTensorWrapper.h> | ||
#include <torch/all.h> | ||
#include <torch/csrc/autograd/function.h> | ||
|
||
namespace torch_ipex { | ||
namespace cpu { | ||
|
||
IPEX_DEFINE_DISPATCH(shm_all_reduce_add_kernel_stub); | ||
|
||
at::Tensor shm_all_reduce_add_forward_cpu( | ||
at::Tensor& t_in, | ||
at::Tensor& t_address, | ||
at::Tensor& t_state, | ||
at::Tensor& t_blockState, | ||
int64_t shm_block_size, | ||
int64_t rank, | ||
int64_t world_size) { | ||
return shm_all_reduce_add_kernel_stub( | ||
kCPU, | ||
t_in, | ||
t_address, | ||
t_state, | ||
t_blockState, | ||
shm_block_size, | ||
rank, | ||
world_size); | ||
} | ||
|
||
} // namespace cpu | ||
} // namespace torch_ipex |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#pragma once

#include <ATen/ATen.h>
#include <dyndisp/DispatchStub.h>

namespace torch_ipex {
namespace cpu {

// Shared-memory all-reduce entry point, declared at namespace scope: an
// anonymous namespace in a header would give every including translation
// unit its own internal declaration that no external definition can satisfy.
// NOTE(review): `t_address`/`t_state`/`t_blockState` presumably carry the
// shared-memory segment and synchronization state, and `shm_block_size` the
// block granularity — confirm against the kernel implementation.
at::Tensor shm_all_reduce_add(
    at::Tensor& t_in,
    at::Tensor& t_address,
    at::Tensor& t_state,
    at::Tensor& t_blockState,
    int64_t shm_block_size,
    int64_t rank,
    int64_t world_size);

// Kernel function-pointer signature for the dispatch stub below.
using shm_all_reduce_add_kernel_fn = at::Tensor (*)(
    at::Tensor& t_in,
    at::Tensor& t_address,
    at::Tensor& t_state,
    at::Tensor& t_blockState,
    int64_t shm_block_size,
    int64_t rank,
    int64_t world_size);

IPEX_DECLARE_DISPATCH(
    shm_all_reduce_add_kernel_fn,
    shm_all_reduce_add_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
38 changes: 38 additions & 0 deletions
38
csrc/cpu/aten/kernels/CollectiveCommunicationPrimitiveKrnl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/Tensor.h> | ||
#include <aten/CollectiveCommunicationPrimitive.h> | ||
#include <comm/messager.h> | ||
#include <torch/csrc/autograd/function.h> | ||
|
||
namespace torch_ipex { | ||
namespace cpu { | ||
|
||
namespace { | ||
at::Tensor all_reduce_add_kernel_impl(at::Tensor& t_in) { | ||
Messenger::getInstance().reduceAdd(t_in); | ||
return t_in; | ||
} | ||
|
||
at::Tensor allgather_kernel_impl( | ||
at::Tensor t_in, | ||
std::vector<int64_t> cols_per_rank, | ||
int64_t world_size) { | ||
std::vector<at::Tensor> output_tensors; | ||
auto shape = t_in.contiguous().sizes(); | ||
for (int64_t rank = 0; rank < world_size; rank++) { | ||
std::vector<int64_t> t_out_shape(shape.begin(), shape.end() - 1); | ||
t_out_shape.push_back(cols_per_rank[rank + 1] - cols_per_rank[rank]); | ||
output_tensors.push_back(at::empty(t_out_shape, t_in.options())); | ||
} | ||
|
||
return Messenger::getInstance().allgather(t_in, output_tensors); | ||
} | ||
|
||
} // anonymous namespace | ||
|
||
IPEX_REGISTER_DISPATCH(all_reduce_add_kernel_stub, &all_reduce_add_kernel_impl); | ||
|
||
IPEX_REGISTER_DISPATCH(allgather_kernel_stub, &allgather_kernel_impl); | ||
|
||
} // namespace cpu | ||
} // namespace torch_ipex |
Oops, something went wrong.