diff --git a/BUILD.md b/BUILD.md
index ef2d1a2bda..c4d8b1b356 100644
--- a/BUILD.md
+++ b/BUILD.md
@@ -26,12 +26,12 @@
In addition to the libraries included with cudatoolkit 11.0+, there are some other dependencies below for building RAFT from source. Many of the dependencies are optional and depend only on the primitives being used. All of these can be installed with cmake or [rapids-cpm](https://github.com/rapidsai/rapids-cmake#cpm) and many of them can be installed with [conda](https://anaconda.org).
#### Required
-- [Thrust](https://github.com/NVIDIA/thrust) v1.15 / [CUB](https://github.com/NVIDIA/cub)
- [RMM](https://github.com/rapidsai/rmm) corresponding to RAFT version.
-- [mdspan](https://github.com/rapidsai/mdspan)
#### Optional
-- [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API
+- [mdspan](https://github.com/rapidsai/mdspan) - On by default but can be disabled.
+- [Thrust](https://github.com/NVIDIA/thrust) v1.15 / [CUB](https://github.com/NVIDIA/cub) - On by default but can be disabled.
+- [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API.
- [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0
- [FAISS](https://github.com/facebookresearch/faiss) v1.7.0 - Used in `raft::spatial::knn` API and needed to build tests.
- [NCCL](https://github.com/NVIDIA/nccl) - Used in `raft::comms` API and needed to build `Pyraft`
@@ -53,6 +53,11 @@ The following example will download the needed dependencies and install the RAFT
./build.sh libraft --install
```
+The `--minimal-deps` flag can be used to install the headers with minimal dependencies:
+```bash
+./build.sh libraft --install --minimal-deps
+```
+
### C++ Shared Libraries (optional)
For larger projects which make heavy use of the pairwise distances or nearest neighbors APIs, shared libraries can be built to speed up compile times. These shared libraries can also significantly improve re-compile times both while developing RAFT and developing against the APIs. Build all of the available shared libraries by passing `--compile-libs` flag to `build.sh`:
@@ -69,7 +74,14 @@ Add the `--install` flag to the above example to also install the shared librari
### Tests
-Compile the tests using the `tests` target in `build.sh`. By default, the shared libraries are assumed to be already built and on the library path. Add `--compile-libs` to also compile them.
+Compile the tests using the `tests` target in `build.sh`.
+
+```bash
+./build.sh libraft tests
+```
+
+Test compile times can be improved significantly by using the optional shared libraries. If installed, they will be used automatically when building the tests but `--compile-libs` can be used to add additional compilation units and compile them with the tests.
+
```bash
./build.sh libraft tests --compile-libs
```
@@ -110,11 +122,13 @@ RAFT's cmake has the following configurable flags available:.
| --- | --- | --- | --- |
| BUILD_TESTS | ON, OFF | ON | Compile Googletests |
| BUILD_BENCH | ON, OFF | ON | Compile benchmarks |
+| raft_FIND_COMPONENTS | nn distance | | Configures the optional components as a space-separated list |
| RAFT_COMPILE_LIBRARIES | ON, OFF | OFF | Compiles all `libraft` shared libraries (these are required for Googletests) |
-| RAFT_COMPILE_NN_LIBRARY | ON, OFF | ON | Compiles the `libraft-nn` shared library |
-| RAFT_COMPILE_DIST_LIBRARY | ON, OFF | ON | Compiles the `libraft-distance` shared library |
+| RAFT_COMPILE_NN_LIBRARY | ON, OFF | OFF | Compiles the `libraft-nn` shared library |
+| RAFT_COMPILE_DIST_LIBRARY | ON, OFF | OFF | Compiles the `libraft-distance` shared library |
| RAFT_ENABLE_NN_DEPENDENCIES | ON, OFF | OFF | Searches for dependencies of nearest neighbors API, such as FAISS, and compiles them if not found. Needed for `raft::spatial::knn` |
-| RAFT_ENABLE_cuco_DEPENDENCY | ON, OFF | ON | Enables the cuCollections dependency used by `raft::sparse::distance` |
+| RAFT_ENABLE_thrust_DEPENDENCY | ON, OFF | ON | Enables the Thrust dependency. This can be disabled when using many simple utilities or to override with a different Thrust version. |
+| RAFT_ENABLE_mdspan_DEPENDENCY | ON, OFF | ON | Enables the std::mdspan dependency. This can be disabled when using many simple utilities. |
| RAFT_ENABLE_nccl_DEPENDENCY | ON, OFF | OFF | Enables NCCL dependency used by `raft::comms` and needed to build `pyraft` |
| RAFT_ENABLE_ucx_DEPENDENCY | ON, OFF | OFF | Enables UCX dependency used by `raft::comms` and needed to build `pyraft` |
| RAFT_USE_FAISS_STATIC | ON, OFF | OFF | Statically link FAISS into `libraft-nn` |
@@ -212,7 +226,8 @@ set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}")
function(find_and_configure_raft)
set(oneValueArgs VERSION FORK PINNED_TAG USE_FAISS_STATIC
COMPILE_LIBRARIES ENABLE_NN_DEPENDENCIES CLONE_ON_PIN
- USE_NN_LIBRARY USE_DISTANCE_LIBRARY)
+ USE_NN_LIBRARY USE_DISTANCE_LIBRARY
+ ENABLE_thrust_DEPENDENCY ENABLE_mdspan_DEPENDENCY)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )
@@ -256,6 +271,8 @@ function(find_and_configure_raft)
"RAFT_ENABLE_NN_DEPENDENCIES ${PKG_ENABLE_NN_DEPENDENCIES}"
"RAFT_USE_FAISS_STATIC ${PKG_USE_FAISS_STATIC}"
"RAFT_COMPILE_LIBRARIES ${PKG_COMPILE_LIBRARIES}"
+ "RAFT_ENABLE_thrust_DEPENDENCY ${PKG_ENABLE_thrust_DEPENDENCY}"
+ "RAFT_ENABLE_mdspan_DEPENDENCY ${PKG_ENABLE_mdspan_DEPENDENCY}"
)
endfunction()
@@ -272,11 +289,13 @@ find_and_configure_raft(VERSION ${RAFT_VERSION}.00
# even if it's already installed.
CLONE_ON_PIN ON
- COMPILE_LIBRARIES NO
- USE_NN_LIBRARY NO
- USE_DISTANCE_LIBRARY NO
- ENABLE_NN_DEPENDENCIES NO # This builds FAISS if not installed
- USE_FAISS_STATIC NO
+ COMPILE_LIBRARIES NO
+ USE_NN_LIBRARY NO
+ USE_DISTANCE_LIBRARY NO
+ ENABLE_NN_DEPENDENCIES NO # This builds FAISS if not installed
+ USE_FAISS_STATIC NO
+ ENABLE_thrust_DEPENDENCY YES
+ ENABLE_mdspan_DEPENDENCY YES
)
```
diff --git a/README.md b/README.md
index f73d474efc..c359a79e39 100755
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
#
RAFT: Reusable Accelerated Functions and Tools
-RAFT contains fundamental widely-used algorithms and primitives for data science, graph and machine learning. The algorithms are CUDA-accelerated and form building-blocks for rapidly composing analytics.
+RAFT contains fundamental widely-used algorithms and primitives for data science and machine learning. The algorithms are CUDA-accelerated and form building-blocks for rapidly composing analytics.
By taking a primitives-based approach to algorithm development, RAFT
- accelerates algorithm construction time
diff --git a/build.sh b/build.sh
index 0c3fbaccb6..568de2956d 100755
--- a/build.sh
+++ b/build.sh
@@ -18,7 +18,7 @@ ARGS=$*
# script, and that this script resides in the repo dir!
REPODIR=$(cd $(dirname $0); pwd)
-VALIDARGS="clean libraft pyraft pylibraft docs tests bench clean -v -g --install --compile-libs --compile-nn --compile-dist --allgpuarch --nvtx --show_depr_warn -h --buildfaiss"
+VALIDARGS="clean libraft pyraft pylibraft docs tests bench clean -v -g --install --compile-libs --compile-nn --compile-dist --allgpuarch --nvtx --show_depr_warn -h --buildfaiss --minimal-deps"
HELP="$0 [ ...] [ ...]
where is:
clean - remove all existing build artifacts and configuration (start over)
@@ -36,6 +36,8 @@ HELP="$0 [ ...] [ ...]
--compile-libs - compile shared libraries for all components
--compile-nn - compile shared library for nn component
--compile-dist - compile shared library for distance component
+ --minimal-deps - disables dependencies like Thrust so they can be overridden.
+ can be useful for a pure header-only install
--allgpuarch - build for all supported GPU architectures
--buildfaiss - build faiss statically into raft
--install - install cmake targets
@@ -62,6 +64,9 @@ COMPILE_LIBRARIES=OFF
COMPILE_NN_LIBRARY=OFF
COMPILE_DIST_LIBRARY=OFF
ENABLE_NN_DEPENDENCIES=OFF
+
+ENABLE_thrust_DEPENDENCY=ON
+
ENABLE_ucx_DEPENDENCY=OFF
ENABLE_nccl_DEPENDENCY=OFF
@@ -105,6 +110,11 @@ fi
if hasArg --install; then
INSTALL_TARGET="install"
fi
+
+if hasArg --minimal-deps; then
+ ENABLE_thrust_DEPENDENCY=OFF
+fi
+
if hasArg -v; then
VERBOSE_FLAG="-v"
CMAKE_LOG_LEVEL="VERBOSE"
@@ -218,7 +228,8 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has
-DRAFT_COMPILE_DIST_LIBRARY=${COMPILE_DIST_LIBRARY} \
-DRAFT_USE_FAISS_STATIC=${BUILD_STATIC_FAISS} \
-DRAFT_ENABLE_nccl_DEPENDENCY=${ENABLE_nccl_DEPENDENCY} \
- -DRAFT_ENABLE_ucx_DEPENDENCY=${ENABLE_ucx_DEPENDENCY}
+ -DRAFT_ENABLE_ucx_DEPENDENCY=${ENABLE_ucx_DEPENDENCY} \
+ -DRAFT_ENABLE_thrust_DEPENDENCY=${ENABLE_thrust_DEPENDENCY}
if [[ ${CMAKE_TARGET} != "" ]]; then
echo "-- Compiling targets: ${CMAKE_TARGET}, verbose=${VERBOSE_FLAG}"
diff --git a/conda/recipes/libraft_distance/meta.yaml b/conda/recipes/libraft_distance/meta.yaml
index ad5a278466..9b78bd15f3 100644
--- a/conda/recipes/libraft_distance/meta.yaml
+++ b/conda/recipes/libraft_distance/meta.yaml
@@ -44,7 +44,6 @@ requirements:
- ucx-py {{ ucx_py_version }}
- ucx-proc=*=gpu
- gtest=1.10.0
- - gmock
- librmm {{ minor_version }}
run:
- libraft-headers {{ version }}
diff --git a/conda/recipes/libraft_headers/build.sh b/conda/recipes/libraft_headers/build.sh
index f239e545ef..d351b27577 100644
--- a/conda/recipes/libraft_headers/build.sh
+++ b/conda/recipes/libraft_headers/build.sh
@@ -1,4 +1,4 @@
#!/usr/bin/env bash
# Copyright (c) 2022, NVIDIA CORPORATION.
-./build.sh libraft --install -v --allgpuarch
+./build.sh libraft --install -v --allgpuarch
\ No newline at end of file
diff --git a/conda/recipes/libraft_headers/meta.yaml b/conda/recipes/libraft_headers/meta.yaml
index ed8dc4373e..fd95da66ee 100644
--- a/conda/recipes/libraft_headers/meta.yaml
+++ b/conda/recipes/libraft_headers/meta.yaml
@@ -43,7 +43,6 @@ requirements:
- ucx-py {{ ucx_py_version }}
- ucx-proc=*=gpu
- gtest=1.10.0
- - gmock
- librmm {{ minor_version}}
- libcusolver>=11.2.1
run:
diff --git a/conda/recipes/libraft_nn/meta.yaml b/conda/recipes/libraft_nn/meta.yaml
index 8cedb15d09..fa3392ddc8 100644
--- a/conda/recipes/libraft_nn/meta.yaml
+++ b/conda/recipes/libraft_nn/meta.yaml
@@ -44,7 +44,6 @@ requirements:
- faiss-proc=*=cuda
- libfaiss 1.7.0 *_cuda
- gtest=1.10.0
- - gmock
- librmm {{ minor_version }}
run:
- {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }}
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 35b066abc9..ab52b766e2 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#=============================================================================
+set(RAPIDS_VERSION "22.06")
+set(RAFT_VERSION "${RAPIDS_VERSION}.00")
cmake_minimum_required(VERSION 3.20.1 FATAL_ERROR)
file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-22.06/RAPIDS.cmake
@@ -26,7 +28,7 @@ include(rapids-find)
rapids_cuda_init_architectures(RAFT)
-project(RAFT VERSION 22.06.00 LANGUAGES CXX CUDA)
+project(RAFT VERSION ${RAFT_VERSION} LANGUAGES CXX CUDA)
# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs to
# have different values for the `Threads::Threads` target. Setting this flag ensures
@@ -55,16 +57,22 @@ option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF)
option(CUDA_ENABLE_LINEINFO "Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF)
option(CUDA_STATIC_RUNTIME "Statically link the CUDA runtime" OFF)
option(DETECT_CONDA_ENV "Enable detection of conda environment for dependencies" ON)
-option(DISABLE_DEPRECATION_WARNINGS "Disable depreaction warnings " ON)
+option(DISABLE_DEPRECATION_WARNINGS "Disable deprecation warnings " ON)
option(DISABLE_OPENMP "Disable OpenMP" OFF)
option(NVTX "Enable nvtx markers" OFF)
-option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ON)
+option(RAFT_COMPILE_LIBRARIES "Enable building raft shared library instantiations" ${BUILD_TESTS})
option(RAFT_COMPILE_NN_LIBRARY "Enable building raft nearest neighbors shared library instantiations" OFF)
option(RAFT_COMPILE_DIST_LIBRARY "Enable building raft distant shared library instantiations" OFF)
option(RAFT_ENABLE_NN_DEPENDENCIES "Search for raft::nn dependencies like faiss" ${RAFT_COMPILE_LIBRARIES})
-option(RAFT_ENABLE_cuco_DEPENDENCY "Enable cuCollections dependency" ON)
+option(RAFT_ENABLE_mdspan_DEPENDENCY "Enable mdspan dependency" ON)
+option(RAFT_ENABLE_thrust_DEPENDENCY "Enable Thrust dependency" ON)
+
+if(BUILD_TESTS AND NOT RAFT_ENABLE_thrust_DEPENDENCY)
+ message(VERBOSE "RAFT: BUILD_TESTS is enabled, overriding RAFT_ENABLE_thrust_DEPENDENCY")
+ set(RAFT_ENABLE_thrust_DEPENDENCY ON)
+endif()
# Currently, UCX and NCCL are only needed to build Pyraft and so a simple find_package() is sufficient
option(RAFT_ENABLE_nccl_DEPENDENCY "Enable NCCL dependency" OFF)
@@ -75,6 +83,7 @@ option(RAFT_EXCLUDE_FAISS_FROM_ALL "Exclude FAISS targets from RAFT's 'all' targ
include(CMakeDependentOption)
cmake_dependent_option(RAFT_USE_FAISS_STATIC "Build and statically link the FAISS library for nearest neighbors search on GPU" ON RAFT_COMPILE_LIBRARIES OFF)
+message(VERBOSE "RAFT: Building optional components: ${raft_FIND_COMPONENTS}")
message(VERBOSE "RAFT: Build RAFT unit-tests: ${BUILD_TESTS}")
message(VERBOSE "RAFT: Building raft C++ benchmarks: ${BUILD_BENCH}")
message(VERBOSE "RAFT: Enable detection of conda environment for dependencies: ${DETECT_CONDA_ENV}")
@@ -123,6 +132,10 @@ include(cmake/modules/ConfigureCUDA.cmake)
##############################################################################
# - Requirements -------------------------------------------------------------
+if(distance IN_LIST raft_FIND_COMPONENTS OR RAFT_COMPILE_LIBRARIES OR RAFT_COMPILE_DIST_LIBRARY)
+ set(RAFT_ENABLE_cuco_DEPENDENCY ON)
+endif()
+
# add third party dependencies using CPM
rapids_cpm_init()
@@ -151,8 +164,10 @@ target_include_directories(raft INTERFACE
"$"
"$")
+# Keep RAFT as lightweight as possible.
+# Only CUDA libs, rmm, and mdspan should
+# be used in global target.
target_link_libraries(raft INTERFACE
- raft::Thrust
$<$:CUDA::nvToolsExt>
CUDA::cublas
CUDA::curand
@@ -160,8 +175,9 @@ target_link_libraries(raft INTERFACE
CUDA::cudart
CUDA::cusparse
rmm::rmm
- $<$:cuco::cuco>
- std::mdspan)
+ $<$:raft::Thrust>
+ $<$:std::mdspan>
+)
target_compile_definitions(raft INTERFACE $<$:NVTX_ENABLED>)
target_compile_features(raft INTERFACE cxx_std_17 $)
@@ -248,6 +264,7 @@ endif()
target_link_libraries(raft_distance INTERFACE
raft::raft
+ $<$:cuco::cuco>
$
$
)
@@ -301,6 +318,7 @@ endif()
target_link_libraries(raft_nn INTERFACE
raft::raft
+ $<$:faiss::faiss>
$
$)
@@ -341,6 +359,9 @@ install(DIRECTORY include/raft
install(FILES include/raft.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/raft)
+install(FILES ${CMAKE_CURRENT_BINARY_DIR}/include/raft/version_config.hpp
+ DESTINATION include/raft)
+
##############################################################################
# - install export -----------------------------------------------------------
set(doc_string
@@ -348,7 +369,7 @@ set(doc_string
Provide targets for the RAFT: Reusable Accelerated Functions and Tools
RAFT contains fundamental widely-used algorithms and primitives
-for data science, graph, and ml.
+for data science and machine learning.
Optional Components:
- nn
@@ -361,13 +382,18 @@ Imported Targets:
]=])
-set(code_string
-[=[
-
-if(NOT TARGET raft::Thrust)
- thrust_create_target(raft::Thrust FROM_OPTIONS)
+set(code_string )
+if(RAFT_ENABLE_thrust_DEPENDENCY)
+ string(APPEND code_string
+ [=[
+ if(NOT TARGET raft::Thrust)
+ thrust_create_target(raft::Thrust FROM_OPTIONS)
+ endif()
+ ]=])
endif()
+string(APPEND code_string
+[=[
if(distance IN_LIST raft_FIND_COMPONENTS)
enable_language(CUDA)
endif()
@@ -381,8 +407,7 @@ if(nn IN_LIST raft_FIND_COMPONENTS)
add_library(faiss ALIAS faiss::faiss)
endif()
endif()
-]=]
- )
+]=])
# Use `rapids_export` for 22.04 as it will have COMPONENT support
include(cmake/modules/raft_export.cmake)
diff --git a/cpp/cmake/thirdparty/get_cuco.cmake b/cpp/cmake/thirdparty/get_cuco.cmake
index a8951a3ee9..c35db4c962 100644
--- a/cpp/cmake/thirdparty/get_cuco.cmake
+++ b/cpp/cmake/thirdparty/get_cuco.cmake
@@ -16,21 +16,20 @@
function(find_and_configure_cuco VERSION)
- if(RAFT_ENABLE_cuco_DEPENDENCY)
- rapids_cpm_find(cuco ${VERSION}
- GLOBAL_TARGETS cuco::cuco
- BUILD_EXPORT_SET raft-exports
- INSTALL_EXPORT_SET raft-exports
- CPM_ARGS
- GIT_REPOSITORY https://github.com/NVIDIA/cuCollections.git
- GIT_TAG fb58a38701f1c24ecfe07d8f1f208bbe80930da5
- OPTIONS "BUILD_TESTS OFF"
- "BUILD_BENCHMARKS OFF"
- "BUILD_EXAMPLES OFF"
- )
- endif()
-
+ rapids_cpm_find(cuco ${VERSION}
+ GLOBAL_TARGETS cuco::cuco
+ BUILD_EXPORT_SET raft-distance-exports
+ INSTALL_EXPORT_SET raft-distance-exports
+ CPM_ARGS
+ GIT_REPOSITORY https://github.com/NVIDIA/cuCollections.git
+ GIT_TAG 6ec8b6dcdeceea07ab4456d32461a05c18864411
+ OPTIONS "BUILD_TESTS OFF"
+ "BUILD_BENCHMARKS OFF"
+ "BUILD_EXAMPLES OFF"
+ )
endfunction()
-# cuCollections doesn't have a version yet
-find_and_configure_cuco(0.0.1)
+if(RAFT_ENABLE_cuco_DEPENDENCY)
+ # cuCollections doesn't have a version yet
+ find_and_configure_cuco(0.0.1)
+endif()
diff --git a/cpp/cmake/thirdparty/get_libcudacxx.cmake b/cpp/cmake/thirdparty/get_libcudacxx.cmake
index a018341b24..92d8e57de9 100644
--- a/cpp/cmake/thirdparty/get_libcudacxx.cmake
+++ b/cpp/cmake/thirdparty/get_libcudacxx.cmake
@@ -14,11 +14,13 @@
# This function finds libcudacxx and sets any additional necessary environment variables.
function(find_and_configure_libcudacxx)
+
include(${rapids-cmake-dir}/cpm/libcudacxx.cmake)
rapids_cpm_libcudacxx(BUILD_EXPORT_SET raft-exports
INSTALL_EXPORT_SET raft-exports)
-
endfunction()
-find_and_configure_libcudacxx()
+if(RAFT_ENABLE_cuco_DEPENDENCY)
+ find_and_configure_libcudacxx()
+endif()
\ No newline at end of file
diff --git a/cpp/cmake/thirdparty/get_mdspan.cmake b/cpp/cmake/thirdparty/get_mdspan.cmake
index 12ac7ab0fd..5af3c4f31e 100644
--- a/cpp/cmake/thirdparty/get_mdspan.cmake
+++ b/cpp/cmake/thirdparty/get_mdspan.cmake
@@ -13,17 +13,19 @@
# =============================================================================
function(find_and_configure_mdspan VERSION)
- rapids_cpm_find(
- mdspan ${VERSION}
- GLOBAL_TARGETS std::mdspan
- BUILD_EXPORT_SET raft-exports
- INSTALL_EXPORT_SET raft-exports
- CPM_ARGS
- GIT_REPOSITORY https://github.com/rapidsai/mdspan.git
- GIT_TAG b3042485358d2ee168ae2b486c98c2c61ec5aec1
- OPTIONS "MDSPAN_ENABLE_CUDA ON"
- "MDSPAN_CXX_STANDARD ON"
- )
+ rapids_cpm_find(
+ mdspan ${VERSION}
+ GLOBAL_TARGETS std::mdspan
+ BUILD_EXPORT_SET raft-exports
+ INSTALL_EXPORT_SET raft-exports
+ CPM_ARGS
+ GIT_REPOSITORY https://github.com/rapidsai/mdspan.git
+ GIT_TAG b3042485358d2ee168ae2b486c98c2c61ec5aec1
+ OPTIONS "MDSPAN_ENABLE_CUDA ON"
+ "MDSPAN_CXX_STANDARD ON"
+ )
endfunction()
-find_and_configure_mdspan(0.2.0)
+if(RAFT_ENABLE_mdspan_DEPENDENCY)
+ find_and_configure_mdspan(0.2.0)
+endif()
diff --git a/cpp/cmake/thirdparty/get_thrust.cmake b/cpp/cmake/thirdparty/get_thrust.cmake
index 03dfecde6a..12360b9482 100644
--- a/cpp/cmake/thirdparty/get_thrust.cmake
+++ b/cpp/cmake/thirdparty/get_thrust.cmake
@@ -14,11 +14,13 @@
# Use CPM to find or clone thrust
function(find_and_configure_thrust)
- include(${rapids-cmake-dir}/cpm/thrust.cmake)
+ include(${rapids-cmake-dir}/cpm/thrust.cmake)
- rapids_cpm_thrust( NAMESPACE raft
- BUILD_EXPORT_SET raft-exports
- INSTALL_EXPORT_SET raft-exports)
+ rapids_cpm_thrust( NAMESPACE raft
+ BUILD_EXPORT_SET raft-exports
+ INSTALL_EXPORT_SET raft-exports)
endfunction()
-find_and_configure_thrust()
+if(RAFT_ENABLE_thrust_DEPENDENCY)
+ find_and_configure_thrust()
+endif()
diff --git a/cpp/include/raft/common/logger.hpp b/cpp/include/raft/common/logger.hpp
index 9066e103d0..77483e577d 100644
--- a/cpp/include/raft/common/logger.hpp
+++ b/cpp/include/raft/common/logger.hpp
@@ -13,286 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#pragma once
-
-#include
-
-#include
-
-#include
-#include
-#include
-#include
-#include
-
-#include
-
-#define SPDLOG_HEADER_ONLY
-#include
-#include // NOLINT
-#include // NOLINT
/**
- * @defgroup logging levels used in raft
- *
- * @note exactly match the corresponding ones (but reverse in terms of value)
- * in spdlog for wrapping purposes
- *
- * @{
+ * This file is deprecated and will be removed in release 22.08.
+ * Please use the include/core/logger.hpp instead.
*/
-#define RAFT_LEVEL_TRACE 6
-#define RAFT_LEVEL_DEBUG 5
-#define RAFT_LEVEL_INFO 4
-#define RAFT_LEVEL_WARN 3
-#define RAFT_LEVEL_ERROR 2
-#define RAFT_LEVEL_CRITICAL 1
-#define RAFT_LEVEL_OFF 0
-/** @} */
-
-#if !defined(RAFT_ACTIVE_LEVEL)
-#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG
-#endif
-namespace raft {
-
-static const std::string RAFT_NAME = "raft";
-static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v");
-
-/**
- * @defgroup CStringFormat Expand a C-style format string
- *
- * @brief Expands C-style formatted string into std::string
- *
- * @param[in] fmt format string
- * @param[in] vl respective values for each of format modifiers in the string
- *
- * @return the expanded `std::string`
- *
- * @{
- */
-std::string format(const char* fmt, va_list& vl)
-{
- char buf[4096];
- vsnprintf(buf, sizeof(buf), fmt, vl);
- return std::string(buf);
-}
-
-std::string format(const char* fmt, ...)
-{
- va_list vl;
- va_start(vl, fmt);
- std::string str = format(fmt, vl);
- va_end(vl);
- return str;
-}
-/** @} */
-
-int convert_level_to_spdlog(int level)
-{
- level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level));
- return RAFT_LEVEL_TRACE - level;
-}
-
-/**
- * @brief The main Logging class for raft library.
- *
- * This class acts as a thin wrapper over the underlying `spdlog` interface. The
- * design is done in this way in order to avoid us having to also ship `spdlog`
- * header files in our installation.
- *
- * @todo This currently only supports logging to stdout. Need to add support in
- * future to add custom loggers as well [Issue #2046]
- */
-class logger {
- public:
- // @todo setting the logger once per process with
- logger(std::string const& name_ = "")
- : sink{std::make_shared()},
- spdlogger{std::make_shared(name_, sink)},
- cur_pattern()
- {
- set_pattern(default_log_pattern);
- set_level(RAFT_LEVEL_INFO);
- }
- /**
- * @brief Singleton method to get the underlying logger object
- *
- * @return the singleton logger object
- */
- static logger& get(std::string const& name = "")
- {
- if (log_map.find(name) == log_map.end()) {
- log_map[name] = std::make_shared(name);
- }
- return *log_map[name];
- }
-
- /**
- * @brief Set the logging level.
- *
- * Only messages with level equal or above this will be printed
- *
- * @param[in] level logging level
- *
- * @note The log level will actually be set only if the input is within the
- * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll
- * be ignored. See documentation of decisiontree for how this gets used
- */
- void set_level(int level)
- {
- level = convert_level_to_spdlog(level);
- spdlogger->set_level(static_cast(level));
- }
-
- /**
- * @brief Set the logging pattern
- *
- * @param[in] pattern the pattern to be set. Refer this link
- * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting
- * to know the right syntax of this pattern
- */
- void set_pattern(const std::string& pattern)
- {
- cur_pattern = pattern;
- spdlogger->set_pattern(pattern);
- }
-
- /**
- * @brief Register a callback function to be run in place of usual log call
- *
- * @param[in] callback the function to be run on all logged messages
- */
- void set_callback(void (*callback)(int lvl, const char* msg)) { sink->set_callback(callback); }
-
- /**
- * @brief Register a flush function compatible with the registered callback
- *
- * @param[in] flush the function to use when flushing logs
- */
- void set_flush(void (*flush)()) { sink->set_flush(flush); }
-
- /**
- * @brief Tells whether messages will be logged for the given log level
- *
- * @param[in] level log level to be checked for
- * @return true if messages will be logged for this level, else false
- */
- bool should_log_for(int level) const
- {
- level = convert_level_to_spdlog(level);
- auto level_e = static_cast(level);
- return spdlogger->should_log(level_e);
- }
-
- /**
- * @brief Query for the current log level
- *
- * @return the current log level
- */
- int get_level() const
- {
- auto level_e = spdlogger->level();
- return RAFT_LEVEL_TRACE - static_cast(level_e);
- }
-
- /**
- * @brief Get the current logging pattern
- * @return the pattern
- */
- std::string get_pattern() const { return cur_pattern; }
-
- /**
- * @brief Main logging method
- *
- * @param[in] level logging level of this message
- * @param[in] fmt C-like format string, followed by respective params
- */
- void log(int level, const char* fmt, ...)
- {
- level = convert_level_to_spdlog(level);
- auto level_e = static_cast(level);
- // explicit check to make sure that we only expand messages when required
- if (spdlogger->should_log(level_e)) {
- va_list vl;
- va_start(vl, fmt);
- auto msg = format(fmt, vl);
- va_end(vl);
- spdlogger->log(level_e, msg);
- }
- }
-
- /**
- * @brief Flush logs by calling flush on underlying logger
- */
- void flush() { spdlogger->flush(); }
-
- ~logger() {}
-
- private:
- logger();
-
- static inline std::unordered_map> log_map;
- std::shared_ptr sink;
- std::shared_ptr spdlogger;
- std::string cur_pattern;
- int cur_level;
-}; // class logger
-
-}; // namespace raft
-
-/**
- * @defgroup loggerMacros Helper macros for dealing with logging
- * @{
- */
-#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE)
-#define RAFT_LOG_TRACE(fmt, ...) \
- do { \
- std::stringstream ss; \
- ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \
- ss << raft::detail::format(fmt, ##__VA_ARGS__); \
- raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \
- } while (0)
-#else
-#define RAFT_LOG_TRACE(fmt, ...) void(0)
-#endif
-
-#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG)
-#define RAFT_LOG_DEBUG(fmt, ...) \
- do { \
- std::stringstream ss; \
- ss << raft::format("%s:%d ", __FILE__, __LINE__); \
- ss << raft::format(fmt, ##__VA_ARGS__); \
- raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \
- } while (0)
-#else
-#define RAFT_LOG_DEBUG(fmt, ...) void(0)
-#endif
-
-#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO)
-#define RAFT_LOG_INFO(fmt, ...) \
- raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__)
-#else
-#define RAFT_LOG_INFO(fmt, ...) void(0)
-#endif
-
-#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN)
-#define RAFT_LOG_WARN(fmt, ...) \
- raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__)
-#else
-#define RAFT_LOG_WARN(fmt, ...) void(0)
-#endif
-
-#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR)
-#define RAFT_LOG_ERROR(fmt, ...) \
- raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__)
-#else
-#define RAFT_LOG_ERROR(fmt, ...) void(0)
-#endif
+#pragma once
-#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL)
-#define RAFT_LOG_CRITICAL(fmt, ...) \
- raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__)
-#else
-#define RAFT_LOG_CRITICAL(fmt, ...) void(0)
-#endif
-/** @} */
+#include
\ No newline at end of file
diff --git a/cpp/include/raft/common/nvtx.hpp b/cpp/include/raft/common/nvtx.hpp
index 918d5e10d8..385bc544b0 100644
--- a/cpp/include/raft/common/nvtx.hpp
+++ b/cpp/include/raft/common/nvtx.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,142 +14,11 @@
* limitations under the License.
*/
-#pragma once
-
-#include "detail/nvtx.hpp"
-#include
-
-/**
- * \section Usage
- *
- * To add NVTX ranges to your code, use the `nvtx::range` RAII object. A
- * range begins when the object is created, and ends when the object is
- * destroyed.
- *
- * The example below creates nested NVTX ranges. The range `fun_scope` spans
- * the whole function, while the range `epoch_scope` spans an iteration
- * (and appears 5 times in the timeline).
- * \code{.cpp}
- * #include
- * void some_function(int k){
- * // Begins a NVTX range with the messsage "some_function_{k}"
- * // The range ends when some_function() returns
- * common::nvtx::range fun_scope( r{"some_function_%d", k};
- *
- * for(int i = 0; i < 5; i++){
- * common::nvtx::range epoch_scope{"epoch-%d", i};
- * // some logic inside the loop
- * }
- * }
- * \endcode
- *
- * \section Domains
- *
- * All NVTX ranges are assigned to domains. A domain defines a named timeline in
- * the Nsight Systems view. By default, we put all ranges into a domain `domain::app`
- * named "application". This is controlled by the template parameter `Domain`.
- *
- * The example below defines a domain and uses it in a function.
- * \code{.cpp}
- * #include
- *
- * struct my_app_domain {
- * static constexpr char const* name{"my application"};
- * }
- *
- * void some_function(int k){
- * // This NVTX range appears in the timeline named "my application" in Nsight Systems.
- * common::nvtx::range fun_scope( r{"some_function_%d", k};
- * // some logic inside the loop
- * }
- * \endcode
- */
-namespace raft::common::nvtx {
-
-namespace domain {
-
-/** @brief The default NVTX domain. */
-struct app {
- static constexpr char const* name{"application"};
-};
-
-/** @brief This NVTX domain is supposed to be used within raft. */
-struct raft {
- static constexpr char const* name{"raft"};
-};
-
-} // namespace domain
-
-/**
- * @brief Push a named NVTX range.
- *
- * @tparam Domain optional struct that defines the NVTX domain message;
- * You can create a new domain with a custom message as follows:
- * \code{.cpp}
- * struct custom_domain { static constexpr char const* name{"custom message"}; }
- * \endcode
- * NB: make sure to use the same domain for `push_range` and `pop_range`.
- * @param format range name format (accepts printf-style arguments)
- * @param args the arguments for the printf-style formatting
- */
-template
-inline void push_range(const char* format, Args... args)
-{
- detail::push_range(format, args...);
-}
-
/**
- * @brief Pop the latest range.
- *
- * @tparam Domain optional struct that defines the NVTX domain message;
- * You can create a new domain with a custom message as follows:
- * \code{.cpp}
- * struct custom_domain { static constexpr char const* name{"custom message"}; }
- * \endcode
- * NB: make sure to use the same domain for `push_range` and `pop_range`.
+ * This file is deprecated and will be removed in release 22.08.
+ * Please use the include/core/nvtx.hpp instead.
*/
-template
-inline void pop_range()
-{
- detail::pop_range();
-}
-/**
- * @brief Push a named NVTX range that would be popped at the end of the object lifetime.
- *
- * Refer to \ref Usage for the usage examples.
- *
- * @tparam Domain optional struct that defines the NVTX domain message;
- * You can create a new domain with a custom message as follows:
- * \code{.cpp}
- * struct custom_domain { static constexpr char const* name{"custom message"}; }
- * \endcode
- */
-template
-class range {
- public:
- /**
- * Push a named NVTX range.
- * At the end of the object lifetime, pop the range back.
- *
- * @param format range name format (accepts printf-style arguments)
- * @param args the arguments for the printf-style formatting
- */
- template
- explicit range(const char* format, Args... args)
- {
- push_range(format, args...);
- }
-
- ~range() { pop_range(); }
-
- /* This object is not meant to be touched. */
- range(const range&) = delete;
- range(range&&) = delete;
- auto operator=(const range&) -> range& = delete;
- auto operator=(range&&) -> range& = delete;
- static auto operator new(std::size_t) -> void* = delete;
- static auto operator new[](std::size_t) -> void* = delete;
-};
+#pragma once
-} // namespace raft::common::nvtx
+#include
\ No newline at end of file
diff --git a/cpp/include/raft/comms/comms.hpp b/cpp/include/raft/comms/comms.hpp
index 9fb2b5a2c6..2ab0f053fc 100644
--- a/cpp/include/raft/comms/comms.hpp
+++ b/cpp/include/raft/comms/comms.hpp
@@ -13,631 +13,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
/**
* This file is deprecated and will be removed in release 22.06.
- * Please use raft_runtime/comms.hpp instead.
+ * Please use core/comms.hpp instead.
*/
-#ifndef __RAFT_RT_COMMS_H
-#define __RAFT_RT_COMMS_H
-
#pragma once
-#include
-#include
-#include
-
-namespace raft {
-namespace comms {
-
-typedef unsigned int request_t;
-enum class datatype_t { CHAR, UINT8, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 };
-enum class op_t { SUM, PROD, MIN, MAX };
-
-/**
- * The resulting status of distributed stream synchronization
- */
-enum class status_t {
- SUCCESS, // Synchronization successful
- ERROR, // An error occured querying sync status
- ABORT // A failure occurred in sync, queued operations aborted
-};
-
-template
-constexpr datatype_t
-
-get_type();
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::CHAR;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::UINT8;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::INT32;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::UINT32;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::INT64;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::UINT64;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::FLOAT32;
-}
-
-template <>
-constexpr datatype_t
-
-get_type()
-{
- return datatype_t::FLOAT64;
-}
-
-class comms_iface {
- public:
- virtual ~comms_iface() {}
-
- virtual int get_size() const = 0;
-
- virtual int get_rank() const = 0;
-
- virtual std::unique_ptr comm_split(int color, int key) const = 0;
-
- virtual void barrier() const = 0;
-
- virtual status_t sync_stream(cudaStream_t stream) const = 0;
-
- virtual void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const = 0;
-
- virtual void irecv(void* buf, size_t size, int source, int tag, request_t* request) const = 0;
-
- virtual void waitall(int count, request_t array_of_requests[]) const = 0;
-
- virtual void allreduce(const void* sendbuff,
- void* recvbuff,
- size_t count,
- datatype_t datatype,
- op_t op,
- cudaStream_t stream) const = 0;
-
- virtual void bcast(
- void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const = 0;
-
- virtual void bcast(const void* sendbuff,
- void* recvbuff,
- size_t count,
- datatype_t datatype,
- int root,
- cudaStream_t stream) const = 0;
-
- virtual void reduce(const void* sendbuff,
- void* recvbuff,
- size_t count,
- datatype_t datatype,
- op_t op,
- int root,
- cudaStream_t stream) const = 0;
-
- virtual void allgather(const void* sendbuff,
- void* recvbuff,
- size_t sendcount,
- datatype_t datatype,
- cudaStream_t stream) const = 0;
-
- virtual void allgatherv(const void* sendbuf,
- void* recvbuf,
- const size_t* recvcounts,
- const size_t* displs,
- datatype_t datatype,
- cudaStream_t stream) const = 0;
-
- virtual void gather(const void* sendbuff,
- void* recvbuff,
- size_t sendcount,
- datatype_t datatype,
- int root,
- cudaStream_t stream) const = 0;
-
- virtual void gatherv(const void* sendbuf,
- void* recvbuf,
- size_t sendcount,
- const size_t* recvcounts,
- const size_t* displs,
- datatype_t datatype,
- int root,
- cudaStream_t stream) const = 0;
-
- virtual void reducescatter(const void* sendbuff,
- void* recvbuff,
- size_t recvcount,
- datatype_t datatype,
- op_t op,
- cudaStream_t stream) const = 0;
-
- // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
- virtual void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const = 0;
-
- // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
- virtual void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const = 0;
-
- virtual void device_sendrecv(const void* sendbuf,
- size_t sendsize,
- int dest,
- void* recvbuf,
- size_t recvsize,
- int source,
- cudaStream_t stream) const = 0;
-
- virtual void device_multicast_sendrecv(const void* sendbuf,
- std::vector const& sendsizes,
- std::vector const& sendoffsets,
- std::vector const& dests,
- void* recvbuf,
- std::vector const& recvsizes,
- std::vector const& recvoffsets,
- std::vector const& sources,
- cudaStream_t stream) const = 0;
-};
-
-class comms_t {
- public:
- comms_t(std::unique_ptr impl) : impl_(impl.release())
- {
- ASSERT(nullptr != impl_.get(), "ERROR: Invalid comms_iface used!");
- }
-
- /**
- * Virtual Destructor to enable polymorphism
- */
- virtual ~comms_t() {}
-
- /**
- * Returns the size of the communicator clique
- */
-
- int get_size() const { return impl_->get_size(); }
-
- /**
- * Returns the local rank
- */
- int get_rank() const { return impl_->get_rank(); }
-
- /**
- * Splits the current communicator clique into sub-cliques matching
- * the given color and key
- *
- * @param color ranks w/ the same color are placed in the same communicator
- * @param key controls rank assignment
- */
- std::unique_ptr comm_split(int color, int key) const
- {
- return impl_->comm_split(color, key);
- }
-
- /**
- * Performs a collective barrier synchronization
- */
- void barrier() const { impl_->barrier(); }
-
- /**
- * Some collective communications implementations (eg. NCCL) might use asynchronous
- * collectives that are explicitly synchronized. It's important to always synchronize
- * using this method to allow failures to propagate, rather than `cudaStreamSynchronize()`,
- * to prevent the potential for deadlocks.
- *
- * @param stream the cuda stream to sync collective operations on
- */
- status_t sync_stream(cudaStream_t stream) const { return impl_->sync_stream(stream); }
-
- /**
- * Performs an asynchronous point-to-point send
- * @tparam value_t the type of data to send
- * @param buf pointer to array of data to send
- * @param size number of elements in buf
- * @param dest destination rank
- * @param tag a tag to use for the receiver to filter
- * @param request pointer to hold returned request_t object.
- * This will be used in `waitall()` to synchronize until the message is delivered (or fails).
- */
- template
- void isend(const value_t* buf, size_t size, int dest, int tag, request_t* request) const
- {
- impl_->isend(static_cast(buf), size * sizeof(value_t), dest, tag, request);
- }
-
- /**
- * Performs an asynchronous point-to-point receive
- * @tparam value_t the type of data to be received
- * @param buf pointer to (initialized) array that will hold received data
- * @param size number of elements in buf
- * @param source source rank
- * @param tag a tag to use for message filtering
- * @param request pointer to hold returned request_t object.
- * This will be used in `waitall()` to synchronize until the message is delivered (or fails).
- */
- template
- void irecv(value_t* buf, size_t size, int source, int tag, request_t* request) const
- {
- impl_->irecv(static_cast(buf), size * sizeof(value_t), source, tag, request);
- }
-
- /**
- * Synchronize on an array of request_t objects returned from isend/irecv
- * @param count number of requests to synchronize on
- * @param array_of_requests an array of request_t objects returned from isend/irecv
- */
- void waitall(int count, request_t array_of_requests[]) const
- {
- impl_->waitall(count, array_of_requests);
- }
-
- /**
- * Perform an allreduce collective
- * @tparam value_t datatype of underlying buffers
- * @param sendbuff data to reduce
- * @param recvbuff buffer to hold the reduced result
- * @param count number of elements in sendbuff
- * @param op reduction operation to perform
- * @param stream CUDA stream to synchronize operation
- */
- template
- void allreduce(
- const value_t* sendbuff, value_t* recvbuff, size_t count, op_t op, cudaStream_t stream) const
- {
- impl_->allreduce(static_cast(sendbuff),
- static_cast(recvbuff),
- count,
- get_type(),
- op,
- stream);
- }
-
- /**
- * Broadcast data from one rank to the rest
- * @tparam value_t datatype of underlying buffers
- * @param buff buffer to send
- * @param count number of elements if buff
- * @param root the rank initiating the broadcast
- * @param stream CUDA stream to synchronize operation
- */
- template
- void bcast(value_t* buff, size_t count, int root, cudaStream_t stream) const
- {
- impl_->bcast(static_cast(buff), count, get_type(), root, stream);
- }
-
- /**
- * Broadcast data from one rank to the rest
- * @tparam value_t datatype of underlying buffers
- * @param sendbuff buffer containing data to broadcast (only used in root)
- * @param recvbuff buffer to receive broadcasted data
- * @param count number of elements if buff
- * @param root the rank initiating the broadcast
- * @param stream CUDA stream to synchronize operation
- */
- template
- void bcast(
- const value_t* sendbuff, value_t* recvbuff, size_t count, int root, cudaStream_t stream) const
- {
- impl_->bcast(static_cast(sendbuff),
- static_cast(recvbuff),
- count,
- get_type(),
- root,
- stream);
- }
-
- /**
- * Reduce data from many ranks down to a single rank
- * @tparam value_t datatype of underlying buffers
- * @param sendbuff buffer containing data to reduce
- * @param recvbuff buffer containing reduced data (only needs to be initialized on root)
- * @param count number of elements in sendbuff
- * @param op reduction operation to perform
- * @param root rank to store the results
- * @param stream CUDA stream to synchronize operation
- */
- template
- void reduce(const value_t* sendbuff,
- value_t* recvbuff,
- size_t count,
- op_t op,
- int root,
- cudaStream_t stream) const
- {
- impl_->reduce(static_cast(sendbuff),
- static_cast(recvbuff),
- count,
- get_type(),
- op,
- root,
- stream);
- }
-
- /**
- * Gathers data from each rank onto all ranks
- * @tparam value_t datatype of underlying buffers
- * @param sendbuff buffer containing data to gather
- * @param recvbuff buffer containing gathered data from all ranks
- * @param sendcount number of elements in send buffer
- * @param stream CUDA stream to synchronize operation
- */
- template
- void allgather(const value_t* sendbuff,
- value_t* recvbuff,
- size_t sendcount,
- cudaStream_t stream) const
- {
- impl_->allgather(static_cast(sendbuff),
- static_cast(recvbuff),
- sendcount,
- get_type(),
- stream);
- }
-
- /**
- * Gathers data from all ranks and delivers to combined data to all ranks
- * @tparam value_t datatype of underlying buffers
- * @param sendbuf buffer containing data to send
- * @param recvbuf buffer containing data to receive
- * @param recvcounts pointer to an array (of length num_ranks size) containing the number of
- * elements that are to be received from each rank
- * @param displs pointer to an array (of length num_ranks size) to specify the displacement
- * (relative to recvbuf) at which to place the incoming data from each rank
- * @param stream CUDA stream to synchronize operation
- */
- template
- void allgatherv(const value_t* sendbuf,
- value_t* recvbuf,
- const size_t* recvcounts,
- const size_t* displs,
- cudaStream_t stream) const
- {
- impl_->allgatherv(static_cast(sendbuf),
- static_cast(recvbuf),
- recvcounts,
- displs,
- get_type(),
- stream);
- }
-
- /**
- * Gathers data from each rank onto all ranks
- * @tparam value_t datatype of underlying buffers
- * @param sendbuff buffer containing data to gather
- * @param recvbuff buffer containing gathered data from all ranks
- * @param sendcount number of elements in send buffer
- * @param root rank to store the results
- * @param stream CUDA stream to synchronize operation
- */
- template
- void gather(const value_t* sendbuff,
- value_t* recvbuff,
- size_t sendcount,
- int root,
- cudaStream_t stream) const
- {
- impl_->gather(static_cast(sendbuff),
- static_cast(recvbuff),
- sendcount,
- get_type(),
- root,
- stream);
- }
-
- /**
- * Gathers data from all ranks and delivers to combined data to all ranks
- * @tparam value_t datatype of underlying buffers
- * @param sendbuf buffer containing data to send
- * @param recvbuf buffer containing data to receive
- * @param sendcount number of elements in send buffer
- * @param recvcounts pointer to an array (of length num_ranks size) containing the number of
- * elements that are to be received from each rank
- * @param displs pointer to an array (of length num_ranks size) to specify the displacement
- * (relative to recvbuf) at which to place the incoming data from each rank
- * @param root rank to store the results
- * @param stream CUDA stream to synchronize operation
- */
- template
- void gatherv(const value_t* sendbuf,
- value_t* recvbuf,
- size_t sendcount,
- const size_t* recvcounts,
- const size_t* displs,
- int root,
- cudaStream_t stream) const
- {
- impl_->gatherv(static_cast(sendbuf),
- static_cast(recvbuf),
- sendcount,
- recvcounts,
- displs,
- get_type(),
- root,
- stream);
- }
-
- /**
- * Reduces data from all ranks then scatters the result across ranks
- * @tparam value_t datatype of underlying buffers
- * @param sendbuff buffer containing data to send (size recvcount * num_ranks)
- * @param recvbuff buffer containing received data
- * @param recvcount number of items to receive
- * @param op reduction operation to perform
- * @param stream CUDA stream to synchronize operation
- */
- template
- void reducescatter(const value_t* sendbuff,
- value_t* recvbuff,
- size_t recvcount,
- op_t op,
- cudaStream_t stream) const
- {
- impl_->reducescatter(static_cast(sendbuff),
- static_cast(recvbuff),
- recvcount,
- get_type(),
- op,
- stream);
- }
-
- /**
- * Performs a point-to-point send
- *
- * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock.
- *
- * @tparam value_t the type of data to send
- * @param buf pointer to array of data to send
- * @param size number of elements in buf
- * @param dest destination rank
- * @param stream CUDA stream to synchronize operation
- */
- template
- void device_send(const value_t* buf, size_t size, int dest, cudaStream_t stream) const
- {
- impl_->device_send(static_cast(buf), size * sizeof(value_t), dest, stream);
- }
-
- /**
- * Performs a point-to-point receive
- *
- * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock.
- *
- * @tparam value_t the type of data to be received
- * @param buf pointer to (initialized) array that will hold received data
- * @param size number of elements in buf
- * @param source source rank
- * @param stream CUDA stream to synchronize operation
- */
- template
- void device_recv(value_t* buf, size_t size, int source, cudaStream_t stream) const
- {
- impl_->device_recv(static_cast(buf), size * sizeof(value_t), source, stream);
- }
-
- /**
- * Performs a point-to-point send/receive
- *
- * @tparam value_t the type of data to be sent & received
- * @param sendbuf pointer to array of data to send
- * @param sendsize number of elements in sendbuf
- * @param dest destination rank
- * @param recvbuf pointer to (initialized) array that will hold received data
- * @param recvsize number of elements in recvbuf
- * @param source source rank
- * @param stream CUDA stream to synchronize operation
- */
- template
- void device_sendrecv(const value_t* sendbuf,
- size_t sendsize,
- int dest,
- value_t* recvbuf,
- size_t recvsize,
- int source,
- cudaStream_t stream) const
- {
- impl_->device_sendrecv(static_cast(sendbuf),
- sendsize * sizeof(value_t),
- dest,
- static_cast(recvbuf),
- recvsize * sizeof(value_t),
- source,
- stream);
- }
-
- /**
- * Performs a multicast send/receive
- *
- * @tparam value_t the type of data to be sent & received
- * @param sendbuf pointer to array of data to send
- * @param sendsizes numbers of elements to send
- * @param sendoffsets offsets in a number of elements from sendbuf
- * @param dests destination ranks
- * @param recvbuf pointer to (initialized) array that will hold received data
- * @param recvsizes numbers of elements to recv
- * @param recvoffsets offsets in a number of elements from recvbuf
- * @param sources source ranks
- * @param stream CUDA stream to synchronize operation
- */
- template
- void device_multicast_sendrecv(const value_t* sendbuf,
- std::vector const& sendsizes,
- std::vector const& sendoffsets,
- std::vector const& dests,
- value_t* recvbuf,
- std::vector const& recvsizes,
- std::vector const& recvoffsets,
- std::vector const& sources,
- cudaStream_t stream) const
- {
- auto sendbytesizes = sendsizes;
- auto sendbyteoffsets = sendoffsets;
- for (size_t i = 0; i < sendsizes.size(); ++i) {
- sendbytesizes[i] *= sizeof(value_t);
- sendbyteoffsets[i] *= sizeof(value_t);
- }
- auto recvbytesizes = recvsizes;
- auto recvbyteoffsets = recvoffsets;
- for (size_t i = 0; i < recvsizes.size(); ++i) {
- recvbytesizes[i] *= sizeof(value_t);
- recvbyteoffsets[i] *= sizeof(value_t);
- }
- impl_->device_multicast_sendrecv(static_cast(sendbuf),
- sendbytesizes,
- sendbyteoffsets,
- dests,
- static_cast(recvbuf),
- recvbytesizes,
- recvbyteoffsets,
- sources,
- stream);
- }
-
- private:
- std::unique_ptr impl_;
-};
-
-} // namespace comms
-} // namespace raft
-
-#endif
+#include
diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp
new file mode 100644
index 0000000000..bf2f7af777
--- /dev/null
+++ b/cpp/include/raft/core/comms.hpp
@@ -0,0 +1,633 @@
+/*
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft {
+namespace comms {
+
+typedef unsigned int request_t;
+enum class datatype_t { CHAR, UINT8, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 };
+enum class op_t { SUM, PROD, MIN, MAX };
+
+/**
+ * The resulting status of distributed stream synchronization
+ */
+enum class status_t {
+ SUCCESS, // Synchronization successful
 + ERROR, // An error occurred querying sync status
+ ABORT // A failure occurred in sync, queued operations aborted
+};
+
+template
+constexpr datatype_t
+
+get_type();
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::CHAR;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::UINT8;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::INT32;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::UINT32;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::INT64;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::UINT64;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::FLOAT32;
+}
+
+template <>
+constexpr datatype_t
+
+get_type()
+{
+ return datatype_t::FLOAT64;
+}
+
+class comms_iface {
+ public:
+ virtual ~comms_iface() {}
+
+ virtual int get_size() const = 0;
+
+ virtual int get_rank() const = 0;
+
+ virtual std::unique_ptr comm_split(int color, int key) const = 0;
+
+ virtual void barrier() const = 0;
+
+ virtual status_t sync_stream(cudaStream_t stream) const = 0;
+
+ virtual void isend(const void* buf, size_t size, int dest, int tag, request_t* request) const = 0;
+
+ virtual void irecv(void* buf, size_t size, int source, int tag, request_t* request) const = 0;
+
+ virtual void waitall(int count, request_t array_of_requests[]) const = 0;
+
+ virtual void allreduce(const void* sendbuff,
+ void* recvbuff,
+ size_t count,
+ datatype_t datatype,
+ op_t op,
+ cudaStream_t stream) const = 0;
+
+ virtual void bcast(
+ void* buff, size_t count, datatype_t datatype, int root, cudaStream_t stream) const = 0;
+
+ virtual void bcast(const void* sendbuff,
+ void* recvbuff,
+ size_t count,
+ datatype_t datatype,
+ int root,
+ cudaStream_t stream) const = 0;
+
+ virtual void reduce(const void* sendbuff,
+ void* recvbuff,
+ size_t count,
+ datatype_t datatype,
+ op_t op,
+ int root,
+ cudaStream_t stream) const = 0;
+
+ virtual void allgather(const void* sendbuff,
+ void* recvbuff,
+ size_t sendcount,
+ datatype_t datatype,
+ cudaStream_t stream) const = 0;
+
+ virtual void allgatherv(const void* sendbuf,
+ void* recvbuf,
+ const size_t* recvcounts,
+ const size_t* displs,
+ datatype_t datatype,
+ cudaStream_t stream) const = 0;
+
+ virtual void gather(const void* sendbuff,
+ void* recvbuff,
+ size_t sendcount,
+ datatype_t datatype,
+ int root,
+ cudaStream_t stream) const = 0;
+
+ virtual void gatherv(const void* sendbuf,
+ void* recvbuf,
+ size_t sendcount,
+ const size_t* recvcounts,
+ const size_t* displs,
+ datatype_t datatype,
+ int root,
+ cudaStream_t stream) const = 0;
+
+ virtual void reducescatter(const void* sendbuff,
+ void* recvbuff,
+ size_t recvcount,
+ datatype_t datatype,
+ op_t op,
+ cudaStream_t stream) const = 0;
+
+ // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
+ virtual void device_send(const void* buf, size_t size, int dest, cudaStream_t stream) const = 0;
+
+ // if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock
+ virtual void device_recv(void* buf, size_t size, int source, cudaStream_t stream) const = 0;
+
+ virtual void device_sendrecv(const void* sendbuf,
+ size_t sendsize,
+ int dest,
+ void* recvbuf,
+ size_t recvsize,
+ int source,
+ cudaStream_t stream) const = 0;
+
+ virtual void device_multicast_sendrecv(const void* sendbuf,
+ std::vector const& sendsizes,
+ std::vector const& sendoffsets,
+ std::vector const& dests,
+ void* recvbuf,
+ std::vector const& recvsizes,
+ std::vector const& recvoffsets,
+ std::vector const& sources,
+ cudaStream_t stream) const = 0;
+};
+
+class comms_t {
+ public:
+ comms_t(std::unique_ptr impl) : impl_(impl.release())
+ {
+ ASSERT(nullptr != impl_.get(), "ERROR: Invalid comms_iface used!");
+ }
+
+ /**
+ * Virtual Destructor to enable polymorphism
+ */
+ virtual ~comms_t() {}
+
+ /**
+ * Returns the size of the communicator clique
+ */
+
+ int get_size() const { return impl_->get_size(); }
+
+ /**
+ * Returns the local rank
+ */
+ int get_rank() const { return impl_->get_rank(); }
+
+ /**
+ * Splits the current communicator clique into sub-cliques matching
+ * the given color and key
+ *
+ * @param color ranks w/ the same color are placed in the same communicator
+ * @param key controls rank assignment
+ */
+ std::unique_ptr comm_split(int color, int key) const
+ {
+ return impl_->comm_split(color, key);
+ }
+
+ /**
+ * Performs a collective barrier synchronization
+ */
+ void barrier() const { impl_->barrier(); }
+
+ /**
+ * Some collective communications implementations (eg. NCCL) might use asynchronous
+ * collectives that are explicitly synchronized. It's important to always synchronize
+ * using this method to allow failures to propagate, rather than `cudaStreamSynchronize()`,
+ * to prevent the potential for deadlocks.
+ *
+ * @param stream the cuda stream to sync collective operations on
+ */
+ status_t sync_stream(cudaStream_t stream) const { return impl_->sync_stream(stream); }
+
+ /**
+ * Performs an asynchronous point-to-point send
+ * @tparam value_t the type of data to send
+ * @param buf pointer to array of data to send
+ * @param size number of elements in buf
+ * @param dest destination rank
+ * @param tag a tag to use for the receiver to filter
+ * @param request pointer to hold returned request_t object.
+ * This will be used in `waitall()` to synchronize until the message is delivered (or fails).
+ */
+ template
+ void isend(const value_t* buf, size_t size, int dest, int tag, request_t* request) const
+ {
+ impl_->isend(static_cast(buf), size * sizeof(value_t), dest, tag, request);
+ }
+
+ /**
+ * Performs an asynchronous point-to-point receive
+ * @tparam value_t the type of data to be received
+ * @param buf pointer to (initialized) array that will hold received data
+ * @param size number of elements in buf
+ * @param source source rank
+ * @param tag a tag to use for message filtering
+ * @param request pointer to hold returned request_t object.
+ * This will be used in `waitall()` to synchronize until the message is delivered (or fails).
+ */
+ template
+ void irecv(value_t* buf, size_t size, int source, int tag, request_t* request) const
+ {
+ impl_->irecv(static_cast(buf), size * sizeof(value_t), source, tag, request);
+ }
+
+ /**
+ * Synchronize on an array of request_t objects returned from isend/irecv
+ * @param count number of requests to synchronize on
+ * @param array_of_requests an array of request_t objects returned from isend/irecv
+ */
+ void waitall(int count, request_t array_of_requests[]) const
+ {
+ impl_->waitall(count, array_of_requests);
+ }
+
+ /**
+ * Perform an allreduce collective
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuff data to reduce
+ * @param recvbuff buffer to hold the reduced result
+ * @param count number of elements in sendbuff
+ * @param op reduction operation to perform
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void allreduce(
+ const value_t* sendbuff, value_t* recvbuff, size_t count, op_t op, cudaStream_t stream) const
+ {
+ impl_->allreduce(static_cast(sendbuff),
+ static_cast(recvbuff),
+ count,
+ get_type(),
+ op,
+ stream);
+ }
+
+ /**
+ * Broadcast data from one rank to the rest
+ * @tparam value_t datatype of underlying buffers
+ * @param buff buffer to send
 + * @param count number of elements in buff
+ * @param root the rank initiating the broadcast
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void bcast(value_t* buff, size_t count, int root, cudaStream_t stream) const
+ {
+ impl_->bcast(static_cast(buff), count, get_type(), root, stream);
+ }
+
+ /**
+ * Broadcast data from one rank to the rest
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuff buffer containing data to broadcast (only used in root)
+ * @param recvbuff buffer to receive broadcasted data
 + * @param count number of elements in buff
+ * @param root the rank initiating the broadcast
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void bcast(
+ const value_t* sendbuff, value_t* recvbuff, size_t count, int root, cudaStream_t stream) const
+ {
+ impl_->bcast(static_cast(sendbuff),
+ static_cast(recvbuff),
+ count,
+ get_type(),
+ root,
+ stream);
+ }
+
+ /**
+ * Reduce data from many ranks down to a single rank
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuff buffer containing data to reduce
+ * @param recvbuff buffer containing reduced data (only needs to be initialized on root)
+ * @param count number of elements in sendbuff
+ * @param op reduction operation to perform
+ * @param root rank to store the results
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void reduce(const value_t* sendbuff,
+ value_t* recvbuff,
+ size_t count,
+ op_t op,
+ int root,
+ cudaStream_t stream) const
+ {
+ impl_->reduce(static_cast(sendbuff),
+ static_cast(recvbuff),
+ count,
+ get_type(),
+ op,
+ root,
+ stream);
+ }
+
+ /**
+ * Gathers data from each rank onto all ranks
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuff buffer containing data to gather
+ * @param recvbuff buffer containing gathered data from all ranks
+ * @param sendcount number of elements in send buffer
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void allgather(const value_t* sendbuff,
+ value_t* recvbuff,
+ size_t sendcount,
+ cudaStream_t stream) const
+ {
+ impl_->allgather(static_cast(sendbuff),
+ static_cast(recvbuff),
+ sendcount,
+ get_type(),
+ stream);
+ }
+
+ /**
 + * Gathers data from all ranks and delivers the combined data to all ranks
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuf buffer containing data to send
+ * @param recvbuf buffer containing data to receive
+ * @param recvcounts pointer to an array (of length num_ranks size) containing the number of
+ * elements that are to be received from each rank
+ * @param displs pointer to an array (of length num_ranks size) to specify the displacement
+ * (relative to recvbuf) at which to place the incoming data from each rank
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void allgatherv(const value_t* sendbuf,
+ value_t* recvbuf,
+ const size_t* recvcounts,
+ const size_t* displs,
+ cudaStream_t stream) const
+ {
+ impl_->allgatherv(static_cast(sendbuf),
+ static_cast(recvbuf),
+ recvcounts,
+ displs,
+ get_type(),
+ stream);
+ }
+
+ /**
+ * Gathers data from each rank onto all ranks
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuff buffer containing data to gather
+ * @param recvbuff buffer containing gathered data from all ranks
+ * @param sendcount number of elements in send buffer
+ * @param root rank to store the results
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void gather(const value_t* sendbuff,
+ value_t* recvbuff,
+ size_t sendcount,
+ int root,
+ cudaStream_t stream) const
+ {
+ impl_->gather(static_cast(sendbuff),
+ static_cast(recvbuff),
+ sendcount,
+ get_type(),
+ root,
+ stream);
+ }
+
+ /**
 + * Gathers data from all ranks and delivers the combined data to all ranks
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuf buffer containing data to send
+ * @param recvbuf buffer containing data to receive
+ * @param sendcount number of elements in send buffer
+ * @param recvcounts pointer to an array (of length num_ranks size) containing the number of
+ * elements that are to be received from each rank
+ * @param displs pointer to an array (of length num_ranks size) to specify the displacement
+ * (relative to recvbuf) at which to place the incoming data from each rank
+ * @param root rank to store the results
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void gatherv(const value_t* sendbuf,
+ value_t* recvbuf,
+ size_t sendcount,
+ const size_t* recvcounts,
+ const size_t* displs,
+ int root,
+ cudaStream_t stream) const
+ {
+ impl_->gatherv(static_cast(sendbuf),
+ static_cast(recvbuf),
+ sendcount,
+ recvcounts,
+ displs,
+ get_type(),
+ root,
+ stream);
+ }
+
+ /**
+ * Reduces data from all ranks then scatters the result across ranks
+ * @tparam value_t datatype of underlying buffers
+ * @param sendbuff buffer containing data to send (size recvcount * num_ranks)
+ * @param recvbuff buffer containing received data
+ * @param recvcount number of items to receive
+ * @param op reduction operation to perform
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void reducescatter(const value_t* sendbuff,
+ value_t* recvbuff,
+ size_t recvcount,
+ op_t op,
+ cudaStream_t stream) const
+ {
+ impl_->reducescatter(static_cast(sendbuff),
+ static_cast(recvbuff),
+ recvcount,
+ get_type(),
+ op,
+ stream);
+ }
+
+ /**
+ * Performs a point-to-point send
+ *
+ * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock.
+ *
+ * @tparam value_t the type of data to send
+ * @param buf pointer to array of data to send
+ * @param size number of elements in buf
+ * @param dest destination rank
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void device_send(const value_t* buf, size_t size, int dest, cudaStream_t stream) const
+ {
+ impl_->device_send(static_cast(buf), size * sizeof(value_t), dest, stream);
+ }
+
+ /**
+ * Performs a point-to-point receive
+ *
+ * if a thread is sending & receiving at the same time, use device_sendrecv to avoid deadlock.
+ *
+ * @tparam value_t the type of data to be received
+ * @param buf pointer to (initialized) array that will hold received data
+ * @param size number of elements in buf
+ * @param source source rank
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void device_recv(value_t* buf, size_t size, int source, cudaStream_t stream) const
+ {
+ impl_->device_recv(static_cast(buf), size * sizeof(value_t), source, stream);
+ }
+
+ /**
+ * Performs a point-to-point send/receive
+ *
+ * @tparam value_t the type of data to be sent & received
+ * @param sendbuf pointer to array of data to send
+ * @param sendsize number of elements in sendbuf
+ * @param dest destination rank
+ * @param recvbuf pointer to (initialized) array that will hold received data
+ * @param recvsize number of elements in recvbuf
+ * @param source source rank
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void device_sendrecv(const value_t* sendbuf,
+ size_t sendsize,
+ int dest,
+ value_t* recvbuf,
+ size_t recvsize,
+ int source,
+ cudaStream_t stream) const
+ {
+ impl_->device_sendrecv(static_cast(sendbuf),
+ sendsize * sizeof(value_t),
+ dest,
+ static_cast(recvbuf),
+ recvsize * sizeof(value_t),
+ source,
+ stream);
+ }
+
+ /**
+ * Performs a multicast send/receive
+ *
+ * @tparam value_t the type of data to be sent & received
+ * @param sendbuf pointer to array of data to send
+ * @param sendsizes numbers of elements to send
+ * @param sendoffsets offsets in a number of elements from sendbuf
+ * @param dests destination ranks
+ * @param recvbuf pointer to (initialized) array that will hold received data
+ * @param recvsizes numbers of elements to recv
+ * @param recvoffsets offsets in a number of elements from recvbuf
+ * @param sources source ranks
+ * @param stream CUDA stream to synchronize operation
+ */
+ template
+ void device_multicast_sendrecv(const value_t* sendbuf,
+ std::vector const& sendsizes,
+ std::vector const& sendoffsets,
+ std::vector const& dests,
+ value_t* recvbuf,
+ std::vector const& recvsizes,
+ std::vector const& recvoffsets,
+ std::vector const& sources,
+ cudaStream_t stream) const
+ {
+ auto sendbytesizes = sendsizes;
+ auto sendbyteoffsets = sendoffsets;
+ for (size_t i = 0; i < sendsizes.size(); ++i) {
+ sendbytesizes[i] *= sizeof(value_t);
+ sendbyteoffsets[i] *= sizeof(value_t);
+ }
+ auto recvbytesizes = recvsizes;
+ auto recvbyteoffsets = recvoffsets;
+ for (size_t i = 0; i < recvsizes.size(); ++i) {
+ recvbytesizes[i] *= sizeof(value_t);
+ recvbyteoffsets[i] *= sizeof(value_t);
+ }
+ impl_->device_multicast_sendrecv(static_cast(sendbuf),
+ sendbytesizes,
+ sendbyteoffsets,
+ dests,
+ static_cast(recvbuf),
+ recvbytesizes,
+ recvbyteoffsets,
+ sources,
+ stream);
+ }
+
+ private:
+ std::unique_ptr impl_;
+};
+
+} // namespace comms
+} // namespace raft
diff --git a/cpp/include/raft/core/cublas_macros.hpp b/cpp/include/raft/core/cublas_macros.hpp
new file mode 100644
index 0000000000..f5de57677d
--- /dev/null
+++ b/cpp/include/raft/core/cublas_macros.hpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __RAFT_RT_CUBLAS_MACROS_H
+#define __RAFT_RT_CUBLAS_MACROS_H
+
+#pragma once
+
+#include
+#include
+
+///@todo: enable this once we have logger enabled
+//#include
+
+#include
+
+#define _CUBLAS_ERR_TO_STR(err) \
+ case err: return #err
+
+namespace raft {
+
+/**
+ * @brief Exception thrown when a cuBLAS error is encountered.
+ */
+struct cublas_error : public raft::exception {
+ explicit cublas_error(char const* const message) : raft::exception(message) {}
+ explicit cublas_error(std::string const& message) : raft::exception(message) {}
+};
+
+namespace linalg {
+namespace detail {
+
+inline const char* cublas_error_to_string(cublasStatus_t err)
+{
+ switch (err) {
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_SUCCESS);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_INITIALIZED);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ALLOC_FAILED);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INVALID_VALUE);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_ARCH_MISMATCH);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_MAPPING_ERROR);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_EXECUTION_FAILED);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_INTERNAL_ERROR);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_NOT_SUPPORTED);
+ _CUBLAS_ERR_TO_STR(CUBLAS_STATUS_LICENSE_ERROR);
+ default: return "CUBLAS_STATUS_UNKNOWN";
+ };
+}
+
+} // namespace detail
+} // namespace linalg
+} // namespace raft
+
+#undef _CUBLAS_ERR_TO_STR
+
+/**
+ * @brief Error checking macro for cuBLAS runtime API functions.
+ *
+ * Invokes a cuBLAS runtime API function call, if the call does not return
+ * CUBLAS_STATUS_SUCCESS, throws an exception detailing the cuBLAS error that occurred
+ */
+#define RAFT_CUBLAS_TRY(call) \
+ do { \
+ cublasStatus_t const status = (call); \
+ if (CUBLAS_STATUS_SUCCESS != status) { \
+ std::string msg{}; \
+ SET_ERROR_MSG(msg, \
+ "cuBLAS error encountered at: ", \
+ "call='%s', Reason=%d:%s", \
+ #call, \
+ status, \
+ raft::linalg::detail::cublas_error_to_string(status)); \
+ throw raft::cublas_error(msg); \
+ } \
+ } while (0)
+
+// FIXME: Remove after consumers rename
+#ifndef CUBLAS_TRY
+#define CUBLAS_TRY(call) RAFT_CUBLAS_TRY(call)
+#endif
+
+// /**
+// * @brief check for cuda runtime API errors but log error instead of raising
+// * exception.
+// */
+#define RAFT_CUBLAS_TRY_NO_THROW(call) \
+ do { \
+ cublasStatus_t const status = call; \
+ if (CUBLAS_STATUS_SUCCESS != status) { \
+ printf("CUBLAS call='%s' at file=%s line=%d failed with %s\n", \
+ #call, \
+ __FILE__, \
+ __LINE__, \
+ raft::linalg::detail::cublas_error_to_string(status)); \
+ } \
+ } while (0)
+
+/** FIXME: remove after cuml rename */
+#ifndef CUBLAS_CHECK
+#define CUBLAS_CHECK(call) CUBLAS_TRY(call)
+#endif
+
+/** FIXME: remove after cuml rename */
+#ifndef CUBLAS_CHECK_NO_THROW
+#define CUBLAS_CHECK_NO_THROW(call) RAFT_CUBLAS_TRY_NO_THROW(call)
+#endif
+
+#endif
\ No newline at end of file
diff --git a/cpp/include/raft/core/cudart_utils.hpp b/cpp/include/raft/core/cudart_utils.hpp
new file mode 100644
index 0000000000..5adc0227a8
--- /dev/null
+++ b/cpp/include/raft/core/cudart_utils.hpp
@@ -0,0 +1,428 @@
+/*
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This file is deprecated and will be removed in release 22.06.
+ * Please use raft_runtime/cudart_utils.hpp instead.
+ */
+
+#ifndef __RAFT_RT_CUDART_UTILS_H
+#define __RAFT_RT_CUDART_UTILS_H
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+///@todo: enable once logging has been enabled in raft
+//#include "logger.hpp"
+
+namespace raft {
+
+/**
+ * @brief Exception thrown when a CUDA error is encountered.
+ */
+struct cuda_error : public raft::exception {
+ explicit cuda_error(char const* const message) : raft::exception(message) {}
+ explicit cuda_error(std::string const& message) : raft::exception(message) {}
+};
+
+} // namespace raft
+
+/**
+ * @brief Error checking macro for CUDA runtime API functions.
+ *
+ * Invokes a CUDA runtime API function call, if the call does not return
+ * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an
+ * exception detailing the CUDA error that occurred
+ *
+ */
+#define RAFT_CUDA_TRY(call) \
+ do { \
+ cudaError_t const status = call; \
+ if (status != cudaSuccess) { \
+ cudaGetLastError(); \
+ std::string msg{}; \
+ SET_ERROR_MSG(msg, \
+ "CUDA error encountered at: ", \
+ "call='%s', Reason=%s:%s", \
+ #call, \
+ cudaGetErrorName(status), \
+ cudaGetErrorString(status)); \
+ throw raft::cuda_error(msg); \
+ } \
+ } while (0)
+
+// FIXME: Remove after consumers rename
+#ifndef CUDA_TRY
+#define CUDA_TRY(call) RAFT_CUDA_TRY(call)
+#endif
+
+/**
+ * @brief Debug macro to check for CUDA errors
+ *
+ * In a non-release build, this macro will synchronize the specified stream
+ * before error checking. In both release and non-release builds, this macro
+ * checks for any pending CUDA errors from previous calls. If an error is
+ * reported, an exception is thrown detailing the CUDA error that occurred.
+ *
+ * The intent of this macro is to provide a mechanism for synchronous and
+ * deterministic execution for debugging asynchronous CUDA execution. It should
+ * be used after any asynchronous CUDA call, e.g., cudaMemcpyAsync, or an
+ * asynchronous kernel launch.
+ */
+#ifndef NDEBUG
+#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+#else
+#define RAFT_CHECK_CUDA(stream) RAFT_CUDA_TRY(cudaPeekAtLastError());
+#endif
+
+// FIXME: Remove after consumers rename
+#ifndef CHECK_CUDA
+#define CHECK_CUDA(call) RAFT_CHECK_CUDA(call)
+#endif
+
+/** FIXME: remove after cuml rename */
+#ifndef CUDA_CHECK
+#define CUDA_CHECK(call) RAFT_CUDA_TRY(call)
+#endif
+
+// /**
+// * @brief check for cuda runtime API errors but log error instead of raising
+// * exception.
+// */
+#define RAFT_CUDA_TRY_NO_THROW(call) \
+ do { \
+ cudaError_t const status = call; \
+ if (cudaSuccess != status) { \
+ printf("CUDA call='%s' at file=%s line=%d failed with %s\n", \
+ #call, \
+ __FILE__, \
+ __LINE__, \
+ cudaGetErrorString(status)); \
+ } \
+ } while (0)
+
+// FIXME: Remove after cuml rename
+#ifndef CUDA_CHECK_NO_THROW
+#define CUDA_CHECK_NO_THROW(call) RAFT_CUDA_TRY_NO_THROW(call)
+#endif
+
+/**
+ * Alias to raft scope for now.
+ * TODO: Rename original implementations in 22.04 to fix
+ * https://github.com/rapidsai/raft/issues/128
+ */
+
+namespace raft {
+
+/** Helper method to get to know warp size in device code */
+__host__ __device__ constexpr inline int warp_size() { return 32; }
+
+__host__ __device__ constexpr inline unsigned int warp_full_mask() { return 0xffffffff; }
+
+/**
+ * @brief A kernel grid configuration construction gadget for simple one-dimensional mapping
+ * elements to threads.
+ */
+class grid_1d_thread_t {
+ public:
+ int const block_size{0};
+ int const num_blocks{0};
+
+ /**
+ * @param overall_num_elements The number of elements the kernel needs to handle/process
+ * @param num_threads_per_block The grid block size, determined according to the kernel's
+ * specific features (amount of shared memory necessary, SM functional units use pattern etc.);
+ * this can't be determined generically/automatically (as opposed to the number of blocks)
+ * @param max_num_blocks_1d maximum number of blocks in 1d grid
+ * @param elements_per_thread Typically, a single kernel thread processes more than a single
+ * element; this affects the number of threads the grid must contain
+ */
+ grid_1d_thread_t(size_t overall_num_elements,
+ size_t num_threads_per_block,
+ size_t max_num_blocks_1d,
+ size_t elements_per_thread = 1)
+ : block_size(num_threads_per_block),
+ num_blocks(
+ std::min((overall_num_elements + (elements_per_thread * num_threads_per_block) - 1) /
+ (elements_per_thread * num_threads_per_block),
+ max_num_blocks_1d))
+ {
+ RAFT_EXPECTS(overall_num_elements > 0, "overall_num_elements must be > 0");
+ RAFT_EXPECTS(num_threads_per_block / warp_size() > 0,
+ "num_threads_per_block / warp_size() must be > 0");
+ RAFT_EXPECTS(elements_per_thread > 0, "elements_per_thread must be > 0");
+ }
+};
+
+/**
+ * @brief A kernel grid configuration construction gadget for simple one-dimensional mapping
+ * elements to warps.
+ */
+class grid_1d_warp_t {
+ public:
+ int const block_size{0};
+ int const num_blocks{0};
+
+ /**
+ * @param overall_num_elements The number of elements the kernel needs to handle/process
+ * @param num_threads_per_block The grid block size, determined according to the kernel's
+ * specific features (amount of shared memory necessary, SM functional units use pattern etc.);
+ * this can't be determined generically/automatically (as opposed to the number of blocks)
+ * @param max_num_blocks_1d maximum number of blocks in 1d grid
+ */
+ grid_1d_warp_t(size_t overall_num_elements,
+ size_t num_threads_per_block,
+ size_t max_num_blocks_1d)
+ : block_size(num_threads_per_block),
+ num_blocks(std::min((overall_num_elements + (num_threads_per_block / warp_size()) - 1) /
+ (num_threads_per_block / warp_size()),
+ max_num_blocks_1d))
+ {
+ RAFT_EXPECTS(overall_num_elements > 0, "overall_num_elements must be > 0");
+ RAFT_EXPECTS(num_threads_per_block / warp_size() > 0,
+ "num_threads_per_block / warp_size() must be > 0");
+ }
+};
+
+/**
+ * @brief A kernel grid configuration construction gadget for simple one-dimensional mapping
+ * elements to blocks.
+ */
+class grid_1d_block_t {
+ public:
+ int const block_size{0};
+ int const num_blocks{0};
+
+ /**
+ * @param overall_num_elements The number of elements the kernel needs to handle/process
+ * @param num_threads_per_block The grid block size, determined according to the kernel's
+ * specific features (amount of shared memory necessary, SM functional units use pattern etc.);
+ * this can't be determined generically/automatically (as opposed to the number of blocks)
+ * @param max_num_blocks_1d maximum number of blocks in 1d grid
+ */
+ grid_1d_block_t(size_t overall_num_elements,
+ size_t num_threads_per_block,
+ size_t max_num_blocks_1d)
+ : block_size(num_threads_per_block),
+ num_blocks(std::min(overall_num_elements, max_num_blocks_1d))
+ {
+ RAFT_EXPECTS(overall_num_elements > 0, "overall_num_elements must be > 0");
+ RAFT_EXPECTS(num_threads_per_block / warp_size() > 0,
+ "num_threads_per_block / warp_size() must be > 0");
+ }
+};
+
+/**
+ * @brief Generic copy method for all kinds of transfers
+ * @tparam Type data type
+ * @param dst destination pointer
+ * @param src source pointer
+ * @param len lenth of the src/dst buffers in terms of number of elements
+ * @param stream cuda stream
+ */
+template
+void copy(Type* dst, const Type* src, size_t len, rmm::cuda_stream_view stream)
+{
+ CUDA_CHECK(cudaMemcpyAsync(dst, src, len * sizeof(Type), cudaMemcpyDefault, stream));
+}
+
+/**
+ * @defgroup Copy Copy methods
+ * These are here along with the generic 'copy' method in order to improve
+ * code readability using explicitly specified function names
+ * @{
+ */
+/** performs a host to device copy */
+template
+void update_device(Type* d_ptr, const Type* h_ptr, size_t len, rmm::cuda_stream_view stream)
+{
+ copy(d_ptr, h_ptr, len, stream);
+}
+
+/** performs a device to host copy */
+template
+void update_host(Type* h_ptr, const Type* d_ptr, size_t len, rmm::cuda_stream_view stream)
+{
+ copy(h_ptr, d_ptr, len, stream);
+}
+
+template
+void copy_async(Type* d_ptr1, const Type* d_ptr2, size_t len, rmm::cuda_stream_view stream)
+{
+ CUDA_CHECK(cudaMemcpyAsync(d_ptr1, d_ptr2, len * sizeof(Type), cudaMemcpyDeviceToDevice, stream));
+}
+/** @} */
+
+/**
+ * @defgroup Debug Utils for debugging host/device buffers
+ * @{
+ */
+template
+void print_host_vector(const char* variable_name,
+ const T* host_mem,
+ size_t componentsCount,
+ OutStream& out)
+{
+ out << variable_name << "=[";
+ for (size_t i = 0; i < componentsCount; ++i) {
+ if (i != 0) out << ",";
+ out << host_mem[i];
+ }
+ out << "];\n";
+}
+
+template
+void print_device_vector(const char* variable_name,
+ const T* devMem,
+ size_t componentsCount,
+ OutStream& out)
+{
+ T* host_mem = new T[componentsCount];
+ CUDA_CHECK(cudaMemcpy(host_mem, devMem, componentsCount * sizeof(T), cudaMemcpyDeviceToHost));
+ print_host_vector(variable_name, host_mem, componentsCount, out);
+ delete[] host_mem;
+}
+/** @} */
+
+/** helper method to get max usable shared mem per block parameter */
+inline int getSharedMemPerBlock()
+{
+ int devId;
+ RAFT_CUDA_TRY(cudaGetDevice(&devId));
+ int smemPerBlk;
+ RAFT_CUDA_TRY(cudaDeviceGetAttribute(&smemPerBlk, cudaDevAttrMaxSharedMemoryPerBlock, devId));
+ return smemPerBlk;
+}
+
+/** helper method to get multi-processor count parameter */
+inline int getMultiProcessorCount()
+{
+ int devId;
+ RAFT_CUDA_TRY(cudaGetDevice(&devId));
+ int mpCount;
+ RAFT_CUDA_TRY(cudaDeviceGetAttribute(&mpCount, cudaDevAttrMultiProcessorCount, devId));
+ return mpCount;
+}
+
+/** helper method to convert an array on device to a string on host */
+template
+std::string arr2Str(const T* arr, int size, std::string name, cudaStream_t stream, int width = 4)
+{
+ std::stringstream ss;
+
+ T* arr_h = (T*)malloc(size * sizeof(T));
+ update_host(arr_h, arr, size, stream);
+ RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
+
+ ss << name << " = [ ";
+ for (int i = 0; i < size; i++) {
+ ss << std::setw(width) << arr_h[i];
+
+ if (i < size - 1) ss << ", ";
+ }
+ ss << " ]" << std::endl;
+
+ free(arr_h);
+
+ return ss.str();
+}
+
+/** this seems to be unused, but may be useful in the future */
+template
+void ASSERT_DEVICE_MEM(T* ptr, std::string name)
+{
+ cudaPointerAttributes s_att;
+ cudaError_t s_err = cudaPointerGetAttributes(&s_att, ptr);
+
+ if (s_err != 0 || s_att.device == -1)
+ std::cout << "Invalid device pointer encountered in " << name << ". device=" << s_att.device
+ << ", err=" << s_err << std::endl;
+}
+
+inline uint32_t curTimeMillis()
+{
+ auto now = std::chrono::high_resolution_clock::now();
+ auto duration = now.time_since_epoch();
+ return std::chrono::duration_cast(duration).count();
+}
+
+/** Helper function to calculate need memory for allocate to store dense matrix.
+ * @param rows number of rows in matrix
+ * @param columns number of columns in matrix
+ * @return need number of items to allocate via allocate()
+ * @sa allocate()
+ */
+inline size_t allocLengthForMatrix(size_t rows, size_t columns) { return rows * columns; }
+
+/** Helper function to check alignment of pointer.
+ * @param ptr the pointer to check
+ * @param alignment to be checked for
+ * @return true if address in bytes is a multiple of alignment
+ */
+template
+bool is_aligned(Type* ptr, size_t alignment)
+{
+ return reinterpret_cast(ptr) % alignment == 0;
+}
+
+/** calculate greatest common divisor of two numbers
+ * @a integer
+ * @b integer
+ * @ return gcd of a and b
+ */
+template
+IntType gcd(IntType a, IntType b)
+{
+ while (b != 0) {
+ IntType tmp = b;
+ b = a % b;
+ a = tmp;
+ }
+ return a;
+}
+
+template
+constexpr T lower_bound()
+{
+ if constexpr (std::numeric_limits::has_infinity && std::numeric_limits::is_signed) {
+ return -std::numeric_limits::infinity();
+ }
+ return std::numeric_limits::lowest();
+}
+
+template
+constexpr T upper_bound()
+{
+ if constexpr (std::numeric_limits::has_infinity) { return std::numeric_limits::infinity(); }
+ return std::numeric_limits::max();
+}
+
+} // namespace raft
+
+#endif
\ No newline at end of file
diff --git a/cpp/include/raft/core/cusolver_macros.hpp b/cpp/include/raft/core/cusolver_macros.hpp
new file mode 100644
index 0000000000..b41927f5fb
--- /dev/null
+++ b/cpp/include/raft/core/cusolver_macros.hpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __RAFT_RT_CUSOLVER_MACROS_H
+#define __RAFT_RT_CUSOLVER_MACROS_H
+
+#pragma once
+
+#include
+#include
+///@todo: enable this once logging is enabled
+//#include
+#include
+#include
+
+#define _CUSOLVER_ERR_TO_STR(err) \
+ case err: return #err;
+
+namespace raft {
+
+/**
+ * @brief Exception thrown when a cuSOLVER error is encountered.
+ */
+struct cusolver_error : public raft::exception {
+ explicit cusolver_error(char const* const message) : raft::exception(message) {}
+ explicit cusolver_error(std::string const& message) : raft::exception(message) {}
+};
+
+namespace linalg {
+namespace detail {
+
+inline const char* cusolver_error_to_string(cusolverStatus_t err)
+{
+ switch (err) {
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_SUCCESS);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_INITIALIZED);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ALLOC_FAILED);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INVALID_VALUE);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ARCH_MISMATCH);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_EXECUTION_FAILED);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_INTERNAL_ERROR);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_ZERO_PIVOT);
+ _CUSOLVER_ERR_TO_STR(CUSOLVER_STATUS_NOT_SUPPORTED);
+ default: return "CUSOLVER_STATUS_UNKNOWN";
+ };
+}
+
+} // namespace detail
+} // namespace linalg
+} // namespace raft
+
+#undef _CUSOLVER_ERR_TO_STR
+
+/**
+ * @brief Error checking macro for cuSOLVER runtime API functions.
+ *
+ * Invokes a cuSOLVER runtime API function call, if the call does not return
+ * CUSolver_STATUS_SUCCESS, throws an exception detailing the cuSOLVER error that occurred
+ */
+#define RAFT_CUSOLVER_TRY(call) \
+ do { \
+ cusolverStatus_t const status = (call); \
+ if (CUSOLVER_STATUS_SUCCESS != status) { \
+ std::string msg{}; \
+ SET_ERROR_MSG(msg, \
+ "cuSOLVER error encountered at: ", \
+ "call='%s', Reason=%d:%s", \
+ #call, \
+ status, \
+ raft::linalg::detail::cusolver_error_to_string(status)); \
+ throw raft::cusolver_error(msg); \
+ } \
+ } while (0)
+
+// FIXME: remove after consumer rename
+#ifndef CUSOLVER_TRY
+#define CUSOLVER_TRY(call) RAFT_CUSOLVER_TRY(call)
+#endif
+
+// /**
+// * @brief check for cuda runtime API errors but log error instead of raising
+// * exception.
+// */
+#define RAFT_CUSOLVER_TRY_NO_THROW(call) \
+ do { \
+ cusolverStatus_t const status = call; \
+ if (CUSOLVER_STATUS_SUCCESS != status) { \
+ printf("CUSOLVER call='%s' at file=%s line=%d failed with %s\n", \
+ #call, \
+ __FILE__, \
+ __LINE__, \
+ raft::linalg::detail::cusolver_error_to_string(status)); \
+ } \
+ } while (0)
+
+// FIXME: remove after cuml rename
+#ifndef CUSOLVER_CHECK
+#define CUSOLVER_CHECK(call) CUSOLVER_TRY(call)
+#endif
+
+#ifndef CUSOLVER_CHECK_NO_THROW
+#define CUSOLVER_CHECK_NO_THROW(call) CUSOLVER_TRY_NO_THROW(call)
+#endif
+
+#endif
\ No newline at end of file
diff --git a/cpp/include/raft/core/cusparse_macros.hpp b/cpp/include/raft/core/cusparse_macros.hpp
new file mode 100644
index 0000000000..10c7e8836c
--- /dev/null
+++ b/cpp/include/raft/core/cusparse_macros.hpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+///@todo: enable this once logging is enabled
+//#include
+
+#define _CUSPARSE_ERR_TO_STR(err) \
+ case err: return #err;
+
+// Notes:
+//(1.) CUDA_VER_10_1_UP aggregates all the CUDA version selection logic;
+//(2.) to enforce a lower version,
+//
+//`#define CUDA_ENFORCE_LOWER
+// #include