diff --git a/README.md b/README.md index bb268a896a..8e0da6cd6d 100755 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ While not exhaustive, the following general categories help summarize the accele All of RAFT's C++ APIs can be accessed header-only and optional pre-compiled shared libraries can 1) speed up compile times and 2) enable the APIs to be used without CUDA-enabled compilers. In addition to the C++ library, RAFT also provides 2 Python libraries: -- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible APIs. +- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible "runtime" APIs. - `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask. ## Getting started @@ -142,7 +142,7 @@ in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) output = pairwise_distance(in1, in2, metric="euclidean") ``` -The `output` array supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) so it is interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. +The `output` array in the above example is of type `raft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) making it interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. CuPy supports DLPack, which also enables zero-copy conversion from `raft.common.device_ndarray` to JAX and Tensorflow. 
Below is an example of converting the output `pylibraft.device_ndarray` to a CuPy array: ```python @@ -156,6 +156,18 @@ import torch torch_tensor = torch.as_tensor(output, device='cuda') ``` +When the corresponding library is installed and available in your environment, this conversion can also be done automatically by all RAFT compute APIs by setting a global configuration option: +```python +import pylibraft.config +pylibraft.config.set_output_as("cupy") # All compute APIs will return cupy arrays +pylibraft.config.set_output_as("torch") # All compute APIs will return torch tensors +``` + +You can also specify a `callable` that accepts a `pylibraft.common.device_ndarray` and performs a custom conversion. The following example converts all output to `numpy` arrays: +```python +pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host()) +``` + `pylibraft` also supports writing to a pre-allocated output array so any `__cuda_array_interface__` supported array can be written to in-place: ```python @@ -257,7 +269,8 @@ Several CMake targets can be made available by adding components in the table be | --- | --- | --- | --- | | n/a | `raft::raft` | Full RAFT header library | CUDA toolkit library, RMM, Thrust (optional), NVTools (optional) | | distance | `raft::distance` | Pre-compiled template specializations for raft::distance | raft::raft, cuCollections (optional) | -| nn | `raft::nn` | Pre-compiled template specializations for raft::spatial::knn | raft::raft, FAISS (optional) | +| nn | `raft::nn` | Pre-compiled template specializations for raft::neighbors | raft::raft, FAISS (optional) | +| distributed | `raft::distributed` | No specializations | raft::raft, UCX, NCCL | ### Source diff --git a/cpp/include/raft/comms/helper.hpp b/cpp/include/raft/comms/helper.hpp deleted file mode 100644 index f6b63ac971..0000000000 --- a/cpp/include/raft/comms/helper.hpp +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) 2020-2022, NVIDIA 
CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include - -namespace raft { -namespace comms { - -/** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. - * - * @param handle raft::handle_t for injecting the comms - * @param nccl_comm initialized NCCL communicator to use for collectives - * @param num_ranks number of ranks in communicator clique - * @param rank rank of local instance - */ -void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank) -{ - cudaStream_t stream = handle->get_stream(); - - auto communicator = std::make_shared( - std::unique_ptr(new raft::comms::std_comms(nccl_comm, num_ranks, rank, stream))); - handle->set_comms(communicator); -} - -/** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. - * - * @param handle raft::handle_t for injecting the comms - * @param nccl_comm initialized NCCL communicator to use for collectives - * @param ucp_worker of local process - * Note: This is purposefully left as void* so that the ucp_worker_h - * doesn't need to be exposed through the cython layer - * @param eps array of ucp_ep_h instances. - * Note: This is purposefully left as void* so that - * the ucp_ep_h doesn't need to be exposed through the cython layer. 
- * @param num_ranks number of ranks in communicator clique - * @param rank rank of local instance - */ -void build_comms_nccl_ucx( - handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) -{ - auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); - - auto size_t_ep_arr = reinterpret_cast(eps); - - for (int i = 0; i < num_ranks; i++) { - size_t ptr = size_t_ep_arr[i]; - auto ucp_ep_v = reinterpret_cast(*eps_sp); - - if (ptr != 0) { - auto eps_ptr = reinterpret_cast(size_t_ep_arr[i]); - ucp_ep_v[i] = eps_ptr; - } else { - ucp_ep_v[i] = nullptr; - } - } - - cudaStream_t stream = handle->get_stream(); - - auto communicator = - std::make_shared(std::unique_ptr(new raft::comms::std_comms( - nccl_comm, (ucp_worker_h)ucp_worker, eps_sp, num_ranks, rank, stream))); - handle->set_comms(communicator); -} - -inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size) -{ - memcpy(id->internal, uniqueId, size); -} - -inline void get_unique_id(char* uid, int size) -{ - ncclUniqueId id; - ncclGetUniqueId(&id); - - memcpy(uid, id.internal, size); -} -}; // namespace comms -}; // end namespace raft diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index ca5275cd06..b3ea62efd2 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -24,6 +24,37 @@ namespace comms { using mpi_comms = detail::mpi_comms; +/** + * @defgroup mpi_comms_factory MPI Comms Factory Functions + * @{ + */ + +/** + * Given a properly initialized MPI_Comm, construct an instance of RAFT's + * MPI Communicator and inject it into the given RAFT handle instance + * @param handle raft handle for managing expensive resources + * @param comm an initialized MPI communicator + * + * @code{.cpp} + * #include + * #include + * + * MPI_Comm mpi_comm; + * raft::handle_t handle; + * + * initialize_mpi_comms(&handle, mpi_comm); + * ... 
+ * const auto& comm = handle.get_comms(); + * auto gather_data = raft::make_device_vector(handle, comm.get_size()); + * ... + * comm.allgather((gather_data.data_handle())[comm.get_rank()], + * gather_data.data_handle(), + * 1, + * handle.get_stream()); + * + * comm.sync_stream(handle.get_stream()); + * @endcode + */ inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) { auto communicator = std::make_shared( @@ -31,5 +62,9 @@ inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) handle->set_comms(communicator); }; +/** + * @} + */ + }; // namespace comms }; // end namespace raft diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index edace60fbd..5e619053da 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -31,13 +31,38 @@ namespace comms { using std_comms = detail::std_comms; /** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. + * @defgroup std_comms_factory std_comms Factory functions + * @{ + */ + +/** + * Factory function to construct a RAFT NCCL communicator and inject it into a + * RAFT handle. * * @param handle raft::handle_t for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param num_ranks number of ranks in communicator clique * @param rank rank of local instance + * + * @code{.cpp} + * #include + * #include + * + * ncclComm_t nccl_comm; + * raft::handle_t handle; + * + * build_comms_nccl_only(&handle, nccl_comm, 5, 0); + * ... + * const auto& comm = handle.get_comms(); + * auto gather_data = raft::make_device_vector(handle, comm.get_size()); + * ... 
+ * comm.allgather((gather_data.data_handle())[comm.get_rank()], + * gather_data.data_handle(), + * 1, + * handle.get_stream()); + * + * comm.sync_stream(handle.get_stream()); + * @endcode */ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank) { @@ -49,8 +74,8 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks } /** - * Function to construct comms_t and inject it on a handle_t. This - * is used for convenience in the Python layer. + * Factory function to construct a RAFT NCCL+UCX communicator and inject it into a RAFT + * handle. * * @param handle raft::handle_t for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives @@ -62,6 +87,28 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks * the ucp_ep_h doesn't need to be exposed through the cython layer. * @param num_ranks number of ranks in communicator clique * @param rank rank of local instance + * + * @code{.cpp} + * #include + * #include + * + * ncclComm_t nccl_comm; + * raft::handle_t handle; + * ucp_worker_h ucp_worker; + * ucp_ep_h *ucp_endpoints_arr; + * + * build_comms_nccl_ucx(&handle, nccl_comm, &ucp_worker, ucp_endpoints_arr, 5, 0); + * ... + * const auto& comm = handle.get_comms(); + * auto gather_data = raft::make_device_vector(handle, comm.get_size()); + * ... 
+ * comm.allgather((gather_data.data_handle())[comm.get_rank()], + * gather_data.data_handle(), + * 1, + * handle.get_stream()); + * + * comm.sync_stream(handle.get_stream()); + * @endcode */ void build_comms_nccl_ucx( handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) @@ -90,6 +137,10 @@ void build_comms_nccl_ucx( handle->set_comms(communicator); } +/** + * @} + */ + inline void nccl_unique_id_from_char(ncclUniqueId* id, char* uniqueId, int size) { memcpy(id->internal, uniqueId, size); diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index 78ce91dbf2..35ab6680de 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -23,6 +23,11 @@ namespace raft { namespace comms { +/** + * @defgroup comms_types Common mnmg comms types + * @{ + */ + typedef unsigned int request_t; enum class datatype_t { CHAR, UINT8, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 }; enum class op_t { SUM, PROD, MIN, MAX }; @@ -105,6 +110,15 @@ get_type() return datatype_t::FLOAT64; } +/** + * @} + */ + +/** + * @defgroup comms_iface MNMG Communicator Interface + * @{ + */ + class comms_iface { public: virtual ~comms_iface() {} @@ -215,6 +229,15 @@ class comms_iface { virtual void group_end() const = 0; }; +/** + * @} + */ + +/** + * @defgroup comms_t Base Communicator Proxy + * @{ + */ + class comms_t { public: comms_t(std::unique_ptr impl) : impl_(impl.release()) @@ -647,5 +670,9 @@ class comms_t { std::unique_ptr impl_; }; +/** + * @} + */ + } // namespace comms } // namespace raft diff --git a/cpp/include/raft/core/cublas_macros.hpp b/cpp/include/raft/core/cublas_macros.hpp index d2456433ab..855c1228f7 100644 --- a/cpp/include/raft/core/cublas_macros.hpp +++ b/cpp/include/raft/core/cublas_macros.hpp @@ -32,6 +32,11 @@ namespace raft { +/** + * @ingroup error_handling + * @{ + */ + /** * @brief Exception thrown when a cuBLAS error is encountered. 
*/ @@ -40,6 +45,10 @@ struct cublas_error : public raft::exception { explicit cublas_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ + namespace linalg { namespace detail { @@ -66,6 +75,11 @@ inline const char* cublas_error_to_string(cublasStatus_t err) #undef _CUBLAS_ERR_TO_STR +/** + * @ingroup assertion + * @{ + */ + /** * @brief Error checking macro for cuBLAS runtime API functions. * @@ -108,6 +122,9 @@ inline const char* cublas_error_to_string(cublasStatus_t err) } \ } while (0) +/** + * @} + */ /** FIXME: remove after cuml rename */ #ifndef CUBLAS_CHECK #define CUBLAS_CHECK(call) CUBLAS_TRY(call) diff --git a/cpp/include/raft/core/cusolver_macros.hpp b/cpp/include/raft/core/cusolver_macros.hpp index 505485e6a0..8f7caf65f3 100644 --- a/cpp/include/raft/core/cusolver_macros.hpp +++ b/cpp/include/raft/core/cusolver_macros.hpp @@ -31,6 +31,11 @@ namespace raft { +/** + * @ingroup error_handling + * @{ + */ + /** * @brief Exception thrown when a cuSOLVER error is encountered. */ @@ -39,6 +44,10 @@ struct cusolver_error : public raft::exception { explicit cusolver_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ + namespace linalg { namespace detail { @@ -65,6 +74,11 @@ inline const char* cusolver_error_to_string(cusolverStatus_t err) #undef _CUSOLVER_ERR_TO_STR +/** + * @ingroup assertion + * @{ + */ + /** * @brief Error checking macro for cuSOLVER runtime API functions. 
* @@ -107,6 +121,10 @@ inline const char* cusolver_error_to_string(cusolverStatus_t err) } \ } while (0) +/** + * @} + */ + // FIXME: remove after cuml rename #ifndef CUSOLVER_CHECK #define CUSOLVER_CHECK(call) CUSOLVER_TRY(call) diff --git a/cpp/include/raft/core/cusparse_macros.hpp b/cpp/include/raft/core/cusparse_macros.hpp index cf5195582b..8a9aab55f7 100644 --- a/cpp/include/raft/core/cusparse_macros.hpp +++ b/cpp/include/raft/core/cusparse_macros.hpp @@ -37,6 +37,11 @@ namespace raft { +/** + * @ingroup error_handling + * @{ + */ + /** * @brief Exception thrown when a cuSparse error is encountered. */ @@ -45,6 +50,9 @@ struct cusparse_error : public raft::exception { explicit cusparse_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ namespace sparse { namespace detail { @@ -73,6 +81,11 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) #undef _CUSPARSE_ERR_TO_STR +/** + * @ingroup assertion + * @{ + */ + /** * @brief Error checking macro for cuSparse runtime API functions. 
* @@ -94,6 +107,10 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) } \ } while (0) +/** + * @} + */ + // FIXME: Remove after consumer rename #ifndef CUSPARSE_TRY #define CUSPARSE_TRY(call) RAFT_CUSPARSE_TRY(call) @@ -104,6 +121,10 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) #define CUSPARSE_CHECK(call) CUSPARSE_TRY(call) #endif +/** + * @ingroup assertion + * @{ + */ //@todo: use logger here once logging is enabled /** check for cusparse runtime API errors but do not assert */ #define RAFT_CUSPARSE_TRY_NO_THROW(call) \ @@ -117,6 +138,10 @@ inline const char* cusparse_error_to_string(cusparseStatus_t err) } \ } while (0) +/** + * @} + */ + // FIXME: Remove after consumer rename #ifndef CUSPARSE_CHECK_NO_THROW #define CUSPARSE_CHECK_NO_THROW(call) RAFT_CUSPARSE_TRY_NO_THROW(call) diff --git a/cpp/include/raft/core/error.hpp b/cpp/include/raft/core/error.hpp index b932309d24..84b244f4dc 100644 --- a/cpp/include/raft/core/error.hpp +++ b/cpp/include/raft/core/error.hpp @@ -30,6 +30,11 @@ namespace raft { +/** + * @defgroup error_handling Exceptions & Error Handling + * @{ + */ + /** base exception class for the whole of raft */ class exception : public std::exception { public: @@ -93,6 +98,10 @@ struct logic_error : public raft::exception { explicit logic_error(std::string const& message) : raft::exception(message) {} }; +/** + * @} + */ + } // namespace raft // FIXME: Need to be replaced with RAFT_FAIL @@ -143,6 +152,11 @@ struct logic_error : public raft::exception { msg += std::string(buf.data(), buf.data() + size - 1); /* -1 to remove final '\0' */ \ } while (0) +/** + * @defgroup assertion Assertion and error macros + * @{ + */ + /** * @brief Macro for checking (pre-)conditions that throws an exception when a condition is false * @@ -174,4 +188,8 @@ struct logic_error : public raft::exception { throw raft::logic_error(msg); \ } while (0) +/** + * @} + */ + #endif diff --git a/cpp/include/raft/core/operators.hpp 
b/cpp/include/raft/core/operators.hpp index 55e1955ae5..de521cc945 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -27,7 +27,7 @@ namespace raft { /** - * @defgroup Functors Commonly used functors. + * @defgroup operators Commonly used functors. * The optional unused arguments are useful for kernels that pass the index along with the value. * @{ */ diff --git a/docs/source/cpp_api.rst b/docs/source/cpp_api.rst index 04656d5047..0e82d81e35 100644 --- a/docs/source/cpp_api.rst +++ b/docs/source/cpp_api.rst @@ -13,6 +13,7 @@ C++ API cpp_api/linalg.rst cpp_api/matrix.rst cpp_api/mdspan.rst + cpp_api/mnmg.rst cpp_api/neighbors.rst cpp_api/random.rst cpp_api/solver.rst diff --git a/docs/source/cpp_api/cluster.rst b/docs/source/cpp_api/cluster.rst index 777977a488..77c8332bbd 100644 --- a/docs/source/cpp_api/cluster.rst +++ b/docs/source/cpp_api/cluster.rst @@ -1,41 +1,17 @@ Cluster ======= -This page provides C++ class references for the publicly-exposed elements of the `raft/cluster` headers. RAFT provides +This page provides C++ API references for the publicly-exposed elements of the `raft/cluster` headers. RAFT provides fundamental clustering algorithms which are, themselves, considered reusable building blocks for other algorithms. .. role:: py(code) :language: c++ :class: highlight -K-Means -####### +.. toctree:: + :maxdepth: 2 + :caption: Contents: -``#include `` - -.. doxygennamespace:: raft::cluster::kmeans - :project: RAFT - :members: - :content-only: - - -Hierarchical Clustering -####################### - -``#include `` - -.. doxygennamespace:: raft::cluster::hierarchy - :project: RAFT - :members: - :content-only: - - -Spectral Clustering -################### - -``#include `` - -.. 
doxygennamespace:: raft::spectral - :project: RAFT - :members: - :content-only: + cluster_kmeans.rst + cluster_slhc.rst + cluster_spectral.rst \ No newline at end of file diff --git a/docs/source/cpp_api/cluster_kmeans.rst b/docs/source/cpp_api/cluster_kmeans.rst new file mode 100644 index 0000000000..fa040ddc18 --- /dev/null +++ b/docs/source/cpp_api/cluster_kmeans.rst @@ -0,0 +1,13 @@ +K-Means +======= + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygennamespace:: raft::cluster::kmeans + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/cluster_slhc.rst b/docs/source/cpp_api/cluster_slhc.rst new file mode 100644 index 0000000000..fc45ae699a --- /dev/null +++ b/docs/source/cpp_api/cluster_slhc.rst @@ -0,0 +1,13 @@ +Hierarchical Clustering +======================= + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygennamespace:: raft::cluster::hierarchy + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/cluster_spectral.rst b/docs/source/cpp_api/cluster_spectral.rst new file mode 100644 index 0000000000..a71f431ab8 --- /dev/null +++ b/docs/source/cpp_api/cluster_spectral.rst @@ -0,0 +1,13 @@ +Spectral Clustering +=================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +.. doxygennamespace:: raft::spectral + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/core.rst b/docs/source/cpp_api/core.rst index 8bf9051739..98365d6485 100644 --- a/docs/source/cpp_api/core.rst +++ b/docs/source/cpp_api/core.rst @@ -10,65 +10,13 @@ expose in public APIs. :language: c++ :class: highlight - -handle_t -######## - -#include - -.. doxygenclass:: raft::handle_t - :project: RAFT - :members: - - -Interruptible -############# - -``#include `` - -.. doxygenclass:: raft::interruptible - :project: RAFT - :members: - - -NVTX -#### - -``#include `` - -.. 
doxygennamespace:: raft::common::nvtx - :project: RAFT - :members: - :content-only: - - -Key-Value Pair -############## - -``#include `` - -.. doxygenstruct:: raft::KeyValuePair - :project: RAFT - :members: - - -logger -###### - -``#include `` - -.. doxygenclass:: raft::logger - :project: RAFT - :members: - - -Multi-node Multi-GPU -#################### - -``#include `` - -.. doxygennamespace:: raft::comms - :project: RAFT - :members: - :content-only: - +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + core_handle.rst + core_logger.rst + core_kvp.rst + core_nvtx.rst + core_interruptible.rst + core_operators.rst \ No newline at end of file diff --git a/docs/source/cpp_api/core_handle.rst b/docs/source/cpp_api/core_handle.rst new file mode 100644 index 0000000000..58fc80681e --- /dev/null +++ b/docs/source/cpp_api/core_handle.rst @@ -0,0 +1,15 @@ +handle_t +======== + +.. role:: py(code) + :language: c++ + :class: highlight + + +``#include `` + +namespace *raft::core* + +.. doxygenclass:: raft::handle_t + :project: RAFT + :members: diff --git a/docs/source/cpp_api/core_interruptible.rst b/docs/source/cpp_api/core_interruptible.rst new file mode 100644 index 0000000000..da767cdd6d --- /dev/null +++ b/docs/source/cpp_api/core_interruptible.rst @@ -0,0 +1,15 @@ +Interruptible +============= + +.. role:: py(code) + :language: c++ + :class: highlight + + +``#include `` + +namespace *raft::core* + +.. doxygenclass:: raft::interruptible + :project: RAFT + :members: diff --git a/docs/source/cpp_api/core_kvp.rst b/docs/source/cpp_api/core_kvp.rst new file mode 100644 index 0000000000..60a0da078b --- /dev/null +++ b/docs/source/cpp_api/core_kvp.rst @@ -0,0 +1,15 @@ +Key-Value Pair +============== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. 
doxygenstruct:: raft::KeyValuePair + :project: RAFT + :members: + diff --git a/docs/source/cpp_api/core_logger.rst b/docs/source/cpp_api/core_logger.rst new file mode 100644 index 0000000000..60714a63ea --- /dev/null +++ b/docs/source/cpp_api/core_logger.rst @@ -0,0 +1,15 @@ +logger +====== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygenclass:: raft::logger + :project: RAFT + :members: + diff --git a/docs/source/cpp_api/core_nvtx.rst b/docs/source/cpp_api/core_nvtx.rst new file mode 100644 index 0000000000..addcbdda30 --- /dev/null +++ b/docs/source/cpp_api/core_nvtx.rst @@ -0,0 +1,17 @@ +NVTX +==== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::core* + +.. doxygennamespace:: raft::common::nvtx + :project: RAFT + :members: + :content-only: + + diff --git a/docs/source/cpp_api/core_operators.rst b/docs/source/cpp_api/core_operators.rst new file mode 100644 index 0000000000..be6443069d --- /dev/null +++ b/docs/source/cpp_api/core_operators.rst @@ -0,0 +1,16 @@ +Operators and Functors +====================== + +.. role:: py(code) + :language: c++ + :class: highlight + + +``#include `` + +namespace *raft::core* + +.. doxygengroup:: operators + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index 20b312a804..eb9bc6255d 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -9,7 +9,7 @@ distances have been highly optimized and support a wide assortment of different :class: highlight Distance Types -############## +-------------- ``#include `` @@ -19,28 +19,9 @@ namespace *raft::distance* :project: RAFT -Pairwise Distance -################# - -``#include `` - -namespace *raft::distance* - -.. 
doxygengroup:: distance_mdspan - :project: RAFT - :members: - :content-only: - - -Fused 1-Nearest Neighbors -######################### - -``#include `` - -namespace *raft::distance* - -.. doxygengroup:: fused_l2_nn - :project: RAFT - :members: - :content-only: +.. toctree:: + :maxdepth: 2 + :caption: Contents: + distance_pairwise.rst + distance_1nn.rst diff --git a/docs/source/cpp_api/distance_1nn.rst b/docs/source/cpp_api/distance_1nn.rst new file mode 100644 index 0000000000..8627069a2d --- /dev/null +++ b/docs/source/cpp_api/distance_1nn.rst @@ -0,0 +1,16 @@ +1-Nearest Neighbors +=================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::distance* + +.. doxygengroup:: fused_l2_nn + :project: RAFT + :members: + :content-only: + diff --git a/docs/source/cpp_api/distance_pairwise.rst b/docs/source/cpp_api/distance_pairwise.rst new file mode 100644 index 0000000000..2a9c9a92f5 --- /dev/null +++ b/docs/source/cpp_api/distance_pairwise.rst @@ -0,0 +1,17 @@ +Pairwise Distance +================= + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::distance* + +.. doxygengroup:: distance_mdspan + :project: RAFT + :members: + :content-only: + + diff --git a/docs/source/cpp_api/linalg_arithmetic.rst b/docs/source/cpp_api/linalg_arithmetic.rst index 496c30a796..7bc428b9f0 100644 --- a/docs/source/cpp_api/linalg_arithmetic.rst +++ b/docs/source/cpp_api/linalg_arithmetic.rst @@ -1,11 +1,6 @@ Arithmetic ========== -This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. -In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience -by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines -hide the complexities of lower-level C-based libraries provided in the CUDA toolkit - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/linalg_blas.rst b/docs/source/cpp_api/linalg_blas.rst index 9dfd106ad9..12133e1dc5 100644 --- a/docs/source/cpp_api/linalg_blas.rst +++ b/docs/source/cpp_api/linalg_blas.rst @@ -1,11 +1,6 @@ BLAS Routines ============= -This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. -In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience -by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines -hide the complexities of lower-level C-based libraries provided in the CUDA toolkit - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/linalg_map_reduce.rst b/docs/source/cpp_api/linalg_map_reduce.rst index 64b2a8f519..5333a23f43 100644 --- a/docs/source/cpp_api/linalg_map_reduce.rst +++ b/docs/source/cpp_api/linalg_map_reduce.rst @@ -1,11 +1,6 @@ Mapping and Reduction ===================== -This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. -In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience -by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines -hide the complexities of lower-level C-based libraries provided in the CUDA toolkit - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/linalg_matrix.rst b/docs/source/cpp_api/linalg_matrix.rst index 983adf5898..e6024bcd02 100644 --- a/docs/source/cpp_api/linalg_matrix.rst +++ b/docs/source/cpp_api/linalg_matrix.rst @@ -1,11 +1,6 @@ Matrix Operations ================= -This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. -In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience -by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines -hide the complexities of lower-level C-based libraries provided in the CUDA toolkit - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/linalg_matrix_vector.rst b/docs/source/cpp_api/linalg_matrix_vector.rst index 72c696fe70..d92a3c9874 100644 --- a/docs/source/cpp_api/linalg_matrix_vector.rst +++ b/docs/source/cpp_api/linalg_matrix_vector.rst @@ -1,11 +1,6 @@ Matrix-Vector Operations ======================== -This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. -In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience -by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines -hide the complexities of lower-level C-based libraries provided in the CUDA toolkit - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/linalg_solver.rst b/docs/source/cpp_api/linalg_solver.rst index d11e5f7801..1a811e072a 100644 --- a/docs/source/cpp_api/linalg_solver.rst +++ b/docs/source/cpp_api/linalg_solver.rst @@ -1,11 +1,6 @@ Linear Algebra Solvers ====================== -This page provides C++ class references for the publicly-exposed elements of the `raft/linalg` (dense) linear algebra headers. -In addition to providing highly optimized arithmetic and matrix/vector operations, RAFT provides a consistent user experience -by providing common BLAS routines, standard linear system solvers, factorization and eigenvalue solvers. Some of these routines -hide the complexities of lower-level C-based libraries provided in the CUDA toolkit - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/matrix_arithmetic.rst b/docs/source/cpp_api/matrix_arithmetic.rst index c1fae55e83..4ed2a41680 100644 --- a/docs/source/cpp_api/matrix_arithmetic.rst +++ b/docs/source/cpp_api/matrix_arithmetic.rst @@ -1,9 +1,6 @@ Matrix Arithmetic ================= -This page provides C++ class references for the publicly-exposed elements of the `raft/matrix` headers. The `raft/matrix` -headers cover many operations on matrices that are otherwise not covered by `raft/linalg`. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/matrix_manipulation.rst b/docs/source/cpp_api/matrix_manipulation.rst index f976e3ccd1..d0da51e4b7 100644 --- a/docs/source/cpp_api/matrix_manipulation.rst +++ b/docs/source/cpp_api/matrix_manipulation.rst @@ -1,9 +1,6 @@ Matrix Manipulation =================== -This page provides C++ class references for the publicly-exposed elements of the `raft/matrix` headers. The `raft/matrix` -headers cover many operations on matrices that are otherwise not covered by `raft/linalg`. - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/matrix_ordering.rst b/docs/source/cpp_api/matrix_ordering.rst index 52275ba5b4..fae6dc12a4 100644 --- a/docs/source/cpp_api/matrix_ordering.rst +++ b/docs/source/cpp_api/matrix_ordering.rst @@ -1,9 +1,6 @@ Matrix Ordering =============== -This page provides C++ class references for the publicly-exposed elements of the `raft/matrix` headers. The `raft/matrix` -headers cover many operations on matrices that are otherwise not covered by `raft/linalg`. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/matrix_reduction.rst b/docs/source/cpp_api/matrix_reduction.rst index fc1a1082aa..440a1528b4 100644 --- a/docs/source/cpp_api/matrix_reduction.rst +++ b/docs/source/cpp_api/matrix_reduction.rst @@ -1,9 +1,6 @@ Matrix Reductions ================= -This page provides C++ class references for the publicly-exposed elements of the `raft/matrix` headers. The `raft/matrix` -headers cover many operations on matrices that are otherwise not covered by `raft/linalg`. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/matrix_selection.rst b/docs/source/cpp_api/matrix_selection.rst index d58f1542ec..4842a75e0e 100644 --- a/docs/source/cpp_api/matrix_selection.rst +++ b/docs/source/cpp_api/matrix_selection.rst @@ -1,9 +1,6 @@ Matrix Selection ================ -This page provides C++ class references for the publicly-exposed elements of the `raft/matrix` headers. The `raft/matrix` -headers cover many operations on matrices that are otherwise not covered by `raft/linalg`. - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/mdspan_mdarray.rst b/docs/source/cpp_api/mdspan_mdarray.rst index 2194060914..bf9e9e0139 100644 --- a/docs/source/cpp_api/mdspan_mdarray.rst +++ b/docs/source/cpp_api/mdspan_mdarray.rst @@ -1,8 +1,6 @@ mdarray: Multi-dimensional Owning Container =========================================== -This page provides C++ class references for the RAFT's 1d span and multi-dimension owning (mdarray) and non-owning (mdspan) APIs. These headers can be found in the `raft/core` directory. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/mdspan_mdspan.rst b/docs/source/cpp_api/mdspan_mdspan.rst index b34fda16bb..272a724833 100644 --- a/docs/source/cpp_api/mdspan_mdspan.rst +++ b/docs/source/cpp_api/mdspan_mdspan.rst @@ -1,8 +1,6 @@ mdspan: Multi-dimensional Non-owning View ========================================== -This page provides C++ class references for the RAFT's 1d span and multi-dimensional owning (mdarray) and non-owning (mdspan) APIs. These headers can be found in the `raft/core` directory. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/mdspan_representation.rst b/docs/source/cpp_api/mdspan_representation.rst index d71c23dcba..fbae03a3e0 100644 --- a/docs/source/cpp_api/mdspan_representation.rst +++ b/docs/source/cpp_api/mdspan_representation.rst @@ -1,8 +1,6 @@ Multi-dimensional Representation ================================ -This page provides C++ class references for the RAFT's 1d span and multi-dimension owning (mdarray) and non-owning (mdspan) APIs. These headers can be found in the `raft/core` directory. - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/mdspan_span.rst b/docs/source/cpp_api/mdspan_span.rst index e633e38445..2bdaf4941e 100644 --- a/docs/source/cpp_api/mdspan_span.rst +++ b/docs/source/cpp_api/mdspan_span.rst @@ -1,8 +1,6 @@ span: One-dimensional Non-owning View ===================================== -This page provides C++ class references for the RAFT's 1d span and multi-dimension owning (mdarray) and non-owning (mdspan) APIs. These headers can be found in the `raft/core` directory. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/mnmg.rst b/docs/source/cpp_api/mnmg.rst new file mode 100644 index 0000000000..9543cbb4ee --- /dev/null +++ b/docs/source/cpp_api/mnmg.rst @@ -0,0 +1,50 @@ +Multi-node Multi-GPU +==================== + +RAFT contains C++ infrastructure for abstracting the communications layer when writing applications that scale on multiple nodes and across multiple GPUs. This infrastructure assumes OPG (one-process per GPU) architectures where multiple physical parallel units (processes, ranks, or workers) might be executing code concurrently but where each parallel unit is communicating with only a single GPU and is the only process communicating with each GPU. + +The comms layer in RAFT is intended to provide a facade API for barrier synchronous collective communications, allowing users to write algorithms using a single abstraction layer and deploy in many different types of systems. Currently, RAFT communications code has been deployed in MPI, Dask, and Spark clusters. + +.. role:: py(code) + :language: c++ + :class: highlight + +Common Types +------------ + +``#include `` + +namespace *raft::comms* + +.. doxygengroup:: comms_types + :project: RAFT + :members: + :content-only: + + +Comms Interface +--------------- + +.. doxygengroup:: comms_t + :project: RAFT + :members: + :content-only: + + +MPI Comms +--------- + +.. 
doxygengroup:: mpi_comms_factory + :project: RAFT + :members: + :content-only: + + +NCCL+UCX Comms +-------------- + +.. doxygengroup:: std_comms_factory + :project: RAFT + :members: + :content-only: + diff --git a/docs/source/cpp_api/neighbors.rst b/docs/source/cpp_api/neighbors.rst index afe3fd6263..9d2e762689 100644 --- a/docs/source/cpp_api/neighbors.rst +++ b/docs/source/cpp_api/neighbors.rst @@ -7,68 +7,12 @@ This page provides C++ class references for the publicly-exposed elements of the :language: c++ :class: highlight - -Brute-force ------------ - -``#include `` - -namespace *raft::neighbors::brute_force* - -.. doxygengroup:: brute_force_knn - :project: RAFT - :members: - :content-only: - - -IVF-Flat --------- - -``#include `` - -namespace *raft::neighbors::ivf_flat* - -.. doxygengroup:: ivf_flat - :project: RAFT - :members: - :content-only: - - -IVF-PQ --------- - -``#include `` - -namespace *raft::neighbors::ivf_pq* - -.. doxygengroup:: ivf_pq - :project: RAFT - :members: - :content-only: - -Epsilon Neighborhood --------------------- - -``#include `` - -namespace *raft::neighbors::epsilon_neighborhood* - -.. doxygengroup:: epsilon_neighbors - :project: RAFT - :members: - :content-only: - - -Random Ball Cover ------------------ - -``#include `` - -namespace *raft::neighbors::ball_cover* - -.. doxygengroup:: random_ball_cover - :project: RAFT - :members: - :content-only: - - +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + neighbors_brute_force.rst + neighbors_ivf_flat.rst + neighbors_ivf_pq.rst + neighbors_epsilon_neighborhood.rst + neighbors_ball_cover.rst \ No newline at end of file diff --git a/docs/source/cpp_api/neighbors_ball_cover.rst b/docs/source/cpp_api/neighbors_ball_cover.rst new file mode 100644 index 0000000000..85bd6b2d8e --- /dev/null +++ b/docs/source/cpp_api/neighbors_ball_cover.rst @@ -0,0 +1,17 @@ +Random Ball Cover +================= + +.. 
role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::neighbors::ball_cover* + +.. doxygengroup:: random_ball_cover + :project: RAFT + :members: + :content-only: + + diff --git a/docs/source/cpp_api/neighbors_brute_force.rst b/docs/source/cpp_api/neighbors_brute_force.rst new file mode 100644 index 0000000000..525addf428 --- /dev/null +++ b/docs/source/cpp_api/neighbors_brute_force.rst @@ -0,0 +1,18 @@ +Brute-Force +=========== + +.. role:: py(code) + :language: c++ + :class: highlight + + +``#include `` + +namespace *raft::neighbors::brute_force* + +.. doxygengroup:: brute_force_knn + :project: RAFT + :members: + :content-only: + + diff --git a/docs/source/cpp_api/neighbors_epsilon_neighborhood.rst b/docs/source/cpp_api/neighbors_epsilon_neighborhood.rst new file mode 100644 index 0000000000..f291a7605f --- /dev/null +++ b/docs/source/cpp_api/neighbors_epsilon_neighborhood.rst @@ -0,0 +1,15 @@ +Epsilon Neighborhood +==================== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::neighbors::epsilon_neighborhood* + +.. doxygengroup:: epsilon_neighbors + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/neighbors_ivf_flat.rst b/docs/source/cpp_api/neighbors_ivf_flat.rst new file mode 100644 index 0000000000..6f418fb165 --- /dev/null +++ b/docs/source/cpp_api/neighbors_ivf_flat.rst @@ -0,0 +1,18 @@ +IVF-Flat +======== + +.. role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::neighbors::ivf_flat* + +.. doxygengroup:: ivf_flat + :project: RAFT + :members: + :content-only: + + + diff --git a/docs/source/cpp_api/neighbors_ivf_pq.rst b/docs/source/cpp_api/neighbors_ivf_pq.rst new file mode 100644 index 0000000000..d22ea6231f --- /dev/null +++ b/docs/source/cpp_api/neighbors_ivf_pq.rst @@ -0,0 +1,17 @@ +IVF-PQ +====== + +.. 
role:: py(code) + :language: c++ + :class: highlight + +``#include `` + +namespace *raft::neighbors::ivf_pq* + +.. doxygengroup:: ivf_pq + :project: RAFT + :members: + :content-only: + + diff --git a/docs/source/cpp_api/random.rst b/docs/source/cpp_api/random.rst index bfff60def5..9f5cdc7a74 100644 --- a/docs/source/cpp_api/random.rst +++ b/docs/source/cpp_api/random.rst @@ -24,5 +24,6 @@ namespace *raft::random* random_datagen.rst random_sampling_univariate.rst - random_samling_multivariable.rst + random_sampling_multivariable.rst random_sampling_without_replacement.rst + diff --git a/docs/source/cpp_api/random_datagen.rst b/docs/source/cpp_api/random_datagen.rst index 0075f1b076..ec23845b6b 100644 --- a/docs/source/cpp_api/random_datagen.rst +++ b/docs/source/cpp_api/random_datagen.rst @@ -1,8 +1,6 @@ Data Generation =============== -This page provides C++ class references for the publicly-exposed elements of the random package. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/random_sampling_multivariable.rst b/docs/source/cpp_api/random_sampling_multivariable.rst index 39950285d0..166043b632 100644 --- a/docs/source/cpp_api/random_sampling_multivariable.rst +++ b/docs/source/cpp_api/random_sampling_multivariable.rst @@ -1,8 +1,6 @@ Multi-Variable Random Sampling ============================== -This page provides C++ class references for the publicly-exposed elements of the random package. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/random_sampling_univariate.rst b/docs/source/cpp_api/random_sampling_univariate.rst index 7a08f58dad..ffa58a0d3a 100644 --- a/docs/source/cpp_api/random_sampling_univariate.rst +++ b/docs/source/cpp_api/random_sampling_univariate.rst @@ -1,8 +1,6 @@ Univariate Random Sampling ========================== -This page provides C++ class references for the publicly-exposed elements of the random package. - .. 
role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/random_sampling_without_replacement.rst b/docs/source/cpp_api/random_sampling_without_replacement.rst index ea78d312b7..ac0d3bea86 100644 --- a/docs/source/cpp_api/random_sampling_without_replacement.rst +++ b/docs/source/cpp_api/random_sampling_without_replacement.rst @@ -1,8 +1,6 @@ Sampling Without Replacement ============================ -This page provides C++ class references for the publicly-exposed elements of the random package. - .. role:: py(code) :language: c++ :class: highlight diff --git a/docs/source/cpp_api/stats.rst b/docs/source/cpp_api/stats.rst index e96f627f81..fd23ce2149 100644 --- a/docs/source/cpp_api/stats.rst +++ b/docs/source/cpp_api/stats.rst @@ -7,319 +7,13 @@ This page provides C++ class references for the publicly-exposed elements of the :language: c++ :class: highlight -Summary Statistics -################## - -Covariance ----------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_cov - :project: RAFT - :members: - :content-only: - -Histogram ---------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_histogram - :project: RAFT - :members: - :content-only: - -Mean ----- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_mean - :project: RAFT - :members: - :content-only: - -Mean Center ------------ - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_mean_center - :project: RAFT - :members: - :content-only: - -Mean Variance -------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_mean_var - :project: RAFT - :members: - :content-only: - -Min/Max -------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_minmax - :project: RAFT - :members: - :content-only: - -Standard Deviation ------------------- - -``#include `` - -namespace *raft::stats* - -.. 
doxygengroup:: stats_stddev - :project: RAFT - :members: - :content-only: - -Sum ---- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_sum - :project: RAFT - :members: - :content-only: - -Weighted Average ----------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_weighted_mean - :project: RAFT - :members: - :content-only: - - -Information Theory & Probability -################################ - -Contingency Matrix ------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: contingency_matrix - :project: RAFT - :members: - :content-only: - -Entropy -------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_entropy - :project: RAFT - :members: - :content-only: - - -KL-Divergence -------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: kl_divergence - :project: RAFT - :members: - :content-only: - -Mutual Information ------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_mutual_info - :project: RAFT - :members: - :content-only: - - -Regression Model Scoring -######################## - -Information Criterion ---------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_information_criterion - :project: RAFT - :members: - :content-only: - -R2 Score --------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_r2_score - :project: RAFT - :members: - :content-only: - - -Regression Metrics ------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_regression_metrics - :project: RAFT - :members: - :content-only: - - -Classification Model Scoring -############################ - -Accuracy --------- - -``#include `` - -namespace *raft::stats* - -.. 
doxygengroup:: stats_accuracy - :project: RAFT - :members: - :content-only: - - -Clustering Model Scoring -######################## - -Adjusted Rand Index -------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_adj_rand_index - :project: RAFT - :members: - :content-only: - -Completeness Score ------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_completeness - :project: RAFT - :members: - :content-only: - -Cluster Dispersion ------------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_cluster_dispersion - :project: RAFT - :members: - :content-only: - - -Rand Index ----------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_rand_index - :project: RAFT - :members: - :content-only: - -Silhouette Score ----------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_silhouette_score - :project: RAFT - :members: - :content-only: - - -V Measure ---------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_vmeasure - :project: RAFT - :members: - :content-only: - - - - -Neighborhood Model Scoring -########################## - -Trustworthiness ---------------- - -``#include `` - -namespace *raft::stats* - -.. doxygengroup:: stats_trustworthiness - :project: RAFT - :members: - :content-only: +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + stats_summary.rst + stats_probability.rst + stats_regression.rst + stats_classification.rst + stats_clustering.rst + stats_neighborhood.rst diff --git a/docs/source/cpp_api/stats_classification.rst b/docs/source/cpp_api/stats_classification.rst new file mode 100644 index 0000000000..929d2808f3 --- /dev/null +++ b/docs/source/cpp_api/stats_classification.rst @@ -0,0 +1,20 @@ +Classification Model Scoring +============================ + +.. role:: py(code) + :language: c++ + :class: highlight + + +Accuracy +-------- + +``#include `` + +namespace *raft::stats* + +.. 
doxygengroup:: stats_accuracy + :project: RAFT + :members: + :content-only: + diff --git a/docs/source/cpp_api/stats_clustering.rst b/docs/source/cpp_api/stats_clustering.rst new file mode 100644 index 0000000000..0ab96cf1f5 --- /dev/null +++ b/docs/source/cpp_api/stats_clustering.rst @@ -0,0 +1,81 @@ +Clustering Model Scoring +======================== + +.. role:: py(code) + :language: c++ + :class: highlight + + +Adjusted Rand Index +------------------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_adj_rand_index + :project: RAFT + :members: + :content-only: + +Completeness Score +------------------ + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_completeness + :project: RAFT + :members: + :content-only: + +Cluster Dispersion +------------------ + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_cluster_dispersion + :project: RAFT + :members: + :content-only: + + +Rand Index +---------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_rand_index + :project: RAFT + :members: + :content-only: + +Silhouette Score +---------------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_silhouette_score + :project: RAFT + :members: + :content-only: + + +V Measure +--------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_vmeasure + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/stats_neighborhood.rst b/docs/source/cpp_api/stats_neighborhood.rst new file mode 100644 index 0000000000..f80e349c3b --- /dev/null +++ b/docs/source/cpp_api/stats_neighborhood.rst @@ -0,0 +1,18 @@ +Neighborhood Model Scoring +========================== + +.. role:: py(code) + :language: c++ + :class: highlight + +Trustworthiness +--------------- + +``#include `` + +namespace *raft::stats* + +.. 
doxygengroup:: stats_trustworthiness + :project: RAFT + :members: + :content-only: diff --git a/docs/source/cpp_api/stats_probability.rst b/docs/source/cpp_api/stats_probability.rst new file mode 100644 index 0000000000..457879d87c --- /dev/null +++ b/docs/source/cpp_api/stats_probability.rst @@ -0,0 +1,56 @@ +Probability & Information Theory +================================ + +.. role:: py(code) + :language: c++ + :class: highlight + +Contingency Matrix +------------------ + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: contingency_matrix + :project: RAFT + :members: + :content-only: + +Entropy +------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_entropy + :project: RAFT + :members: + :content-only: + + +KL-Divergence +------------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: kl_divergence + :project: RAFT + :members: + :content-only: + +Mutual Information +------------------ + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_mutual_info + :project: RAFT + :members: + :content-only: + diff --git a/docs/source/cpp_api/stats_regression.rst b/docs/source/cpp_api/stats_regression.rst new file mode 100644 index 0000000000..8c172b441d --- /dev/null +++ b/docs/source/cpp_api/stats_regression.rst @@ -0,0 +1,45 @@ +Regression Model Scoring +======================== + +.. role:: py(code) + :language: c++ + :class: highlight + +Information Criterion +--------------------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_information_criterion + :project: RAFT + :members: + :content-only: + +R2 Score +-------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_r2_score + :project: RAFT + :members: + :content-only: + + +Regression Metrics +------------------ + +``#include `` + +namespace *raft::stats* + +.. 
doxygengroup:: stats_regression_metrics + :project: RAFT + :members: + :content-only: + + diff --git a/docs/source/cpp_api/stats_summary.rst b/docs/source/cpp_api/stats_summary.rst new file mode 100644 index 0000000000..7b4bf6a801 --- /dev/null +++ b/docs/source/cpp_api/stats_summary.rst @@ -0,0 +1,114 @@ +Summary Statistics +================== + +.. role:: py(code) + :language: c++ + :class: highlight + +Covariance +---------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_cov + :project: RAFT + :members: + :content-only: + +Histogram +--------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_histogram + :project: RAFT + :members: + :content-only: + +Mean +---- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_mean + :project: RAFT + :members: + :content-only: + +Mean Center +----------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_mean_center + :project: RAFT + :members: + :content-only: + +Mean Variance +------------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_mean_var + :project: RAFT + :members: + :content-only: + +Min/Max +------- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_minmax + :project: RAFT + :members: + :content-only: + +Standard Deviation +------------------ + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_stddev + :project: RAFT + :members: + :content-only: + +Sum +--- + +``#include `` + +namespace *raft::stats* + +.. doxygengroup:: stats_sum + :project: RAFT + :members: + :content-only: + +Weighted Average +---------------- + +``#include `` + +namespace *raft::stats* + +.. 
doxygengroup:: stats_weighted_mean + :project: RAFT + :members: + :content-only: diff --git a/docs/source/pylibraft_api.rst b/docs/source/pylibraft_api.rst index d6bda89c21..84955283cb 100644 --- a/docs/source/pylibraft_api.rst +++ b/docs/source/pylibraft_api.rst @@ -1,6 +1,6 @@ -~~~~~~~~~~~~~ -PyLibRAFT API -~~~~~~~~~~~~~ +~~~~~~~~~~ +Python API +~~~~~~~~~~ .. _api: diff --git a/docs/source/pylibraft_api/cluster.rst b/docs/source/pylibraft_api/cluster.rst index c70fd46b2a..59e53e7d4c 100644 --- a/docs/source/pylibraft_api/cluster.rst +++ b/docs/source/pylibraft_api/cluster.rst @@ -7,7 +7,15 @@ This page provides pylibraft class references for the publicly-exposed elements :language: python :class: highlight +.. autoclass:: pylibraft.cluster.kmeans.KMeansParams + :members: + +.. autofunction:: pylibraft.cluster.kmeans.fit + +.. autofunction:: pylibraft.cluster.kmeans.cluster_cost + .. autofunction:: pylibraft.cluster.compute_new_centroids + diff --git a/docs/source/pylibraft_api/neighbors.rst b/docs/source/pylibraft_api/neighbors.rst index 14046fa97a..89bb577027 100644 --- a/docs/source/pylibraft_api/neighbors.rst +++ b/docs/source/pylibraft_api/neighbors.rst @@ -7,6 +7,10 @@ This page provides pylibraft class references for the publicly-exposed elements :language: python :class: highlight + +IVF-PQ +###### + .. autoclass:: pylibraft.neighbors.ivf_pq.IndexParams :members: @@ -18,3 +22,9 @@ This page provides pylibraft class references for the publicly-exposed elements :members: .. autofunction:: pylibraft.neighbors.ivf_pq.search + + +Candidate Refinement +#################### + +.. 
autofunction:: pylibraft.neighbors.refine diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index d8cc5ce08b..e955706dc4 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -8,9 +8,9 @@ RAFT relies heavily on the [RMM](https://github.com/rapidsai/rmm) library which ## Multi-dimensional Spans and Arrays -The APIs in RAFT currently accept raw pointers to device memory and we are in the process of simplifying the APIs with the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. +Most of the APIs in RAFT accept [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. -The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: +The `mdarray` is an owning object that forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: ```c++ #include @@ -118,11 +118,11 @@ auto metric = raft::distance::DistanceType::L2SqrtExpanded; raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); ``` -## Python Example +### Python Example -The `pylibraft` package contains a Python API for RAFT algorithms and primitives. 
`pylibraft` integrates nicely into other libraries by being very lightweight with minimal dependencies and accepting any object that supports the `__cuda_array_interface__`, such as [CuPy's ndarray](https://docs.cupy.dev/en/stable/user_guide/interoperability.html#rmm). The package is currently limited to pairwise distances and RMAT graph generation, but we will continue adding more in future releases. +The `pylibraft` package contains a Python API for RAFT algorithms and primitives. `pylibraft` integrates nicely into other libraries by being very lightweight with minimal dependencies and accepting any object that supports the `__cuda_array_interface__`, such as [CuPy's ndarray](https://docs.cupy.dev/en/stable/user_guide/interoperability.html#rmm). The number of RAFT algorithms exposed in this package is continuing to grow from release to release. -The example below demonstrates computing the pairwise Euclidean distances between CuPy arrays. `pylibraft` is a low-level API that prioritizes efficiency and simplicity over being pythonic, which is shown here by pre-allocating the output memory before invoking the `pairwise_distance` function. Note that CuPy is not a required dependency for `pylibraft`. +The example below demonstrates computing the pairwise Euclidean distances between CuPy arrays. Note that CuPy is not a required dependency for `pylibraft`. ```python import cupy as cp @@ -137,3 +137,47 @@ in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) output = pairwise_distance(in1, in2, metric="euclidean") ``` + +The `output` array in the above example is of type `pylibraft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) making it interoperable with other libraries like CuPy, Numba, and PyTorch that also support it.
CuPy supports DLPack, which also enables zero-copy conversion from `pylibraft.common.device_ndarray` to JAX and Tensorflow. + +Below is an example of converting the output `pylibraft.common.device_ndarray` to a CuPy array: +```python +cupy_array = cp.asarray(output) +``` + +And converting to a PyTorch tensor: +```python +import torch + +torch_tensor = torch.as_tensor(output, device='cuda') +``` + +When the corresponding library is installed and available in your environment, this conversion can also be done automatically by all RAFT compute APIs by setting a global configuration option: +```python +import pylibraft.config +pylibraft.config.set_output_as("cupy") # All compute APIs will return cupy arrays +pylibraft.config.set_output_as("torch") # All compute APIs will return torch tensors +``` + +You can also specify a `callable` that accepts a `pylibraft.common.device_ndarray` and performs a custom conversion. The following example converts all output to `numpy` arrays: +```python +pylibraft.config.set_output_as(lambda device_ndarray: device_ndarray.copy_to_host()) +``` + + +`pylibraft` also supports writing to a pre-allocated output array so any `__cuda_array_interface__` supported array can be written to in-place: + +```python +import cupy as cp + +from pylibraft.distance import pairwise_distance + +n_samples = 5000 +n_features = 50 + +in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) +in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) +output = cp.empty((n_samples, n_samples), dtype=cp.float32) + +pairwise_distance(in1, in2, out=output, metric="euclidean") +``` diff --git a/docs/source/raft_dask_api.rst b/docs/source/raft_dask_api.rst index 10ba8781a2..44720c188c 100644 --- a/docs/source/raft_dask_api.rst +++ b/docs/source/raft_dask_api.rst @@ -1,6 +1,6 @@ -~~~~~~~~~~~~~~~~~~~~~~~ -RAFT Dask API Reference -~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~ +RAFT Dask API +~~~~~~~~~~~~~ ..
role:: py(code) :language: python diff --git a/python/pylibraft/pylibraft/__init__.py b/python/pylibraft/pylibraft/__init__.py index 1124c64102..c1a5bf1663 100644 --- a/python/pylibraft/pylibraft/__init__.py +++ b/python/pylibraft/pylibraft/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. # +import pylibraft.config from pylibraft._version import get_versions __version__ = get_versions()["version"] diff --git a/python/pylibraft/pylibraft/cluster/kmeans.pyx b/python/pylibraft/pylibraft/cluster/kmeans.pyx index 9097eccfa8..f2e010f6a5 100644 --- a/python/pylibraft/pylibraft/cluster/kmeans.pyx +++ b/python/pylibraft/pylibraft/cluster/kmeans.pyx @@ -45,8 +45,11 @@ from pylibraft.common.cpp.mdspan cimport * from pylibraft.common.cpp.optional cimport optional from pylibraft.common.handle cimport handle_t +from pylibraft.common import auto_convert_output + @auto_sync_handle +@auto_convert_output def compute_new_centroids(X, centroids, labels, @@ -197,6 +200,7 @@ def compute_new_centroids(X, @auto_sync_handle +@auto_convert_output def cluster_cost(X, centroids, handle=None): """ Compute cluster cost given an input matrix and existing centroids @@ -403,6 +407,7 @@ FitOutput = namedtuple("FitOutput", "centroids inertia n_iter") @auto_sync_handle +@auto_convert_output def fit( KMeansParams params, X, centroids=None, sample_weights=None, handle=None ): diff --git a/python/pylibraft/pylibraft/common/__init__.py b/python/pylibraft/pylibraft/common/__init__.py index 9c0f631b86..4f87720030 100644 --- a/python/pylibraft/pylibraft/common/__init__.py +++ b/python/pylibraft/pylibraft/common/__init__.py @@ -17,5 +17,6 @@ from .cuda import Stream from .device_ndarray import device_ndarray from .handle import Handle +from .outputs import auto_convert_output __all__ = ["Handle", "Stream"] diff --git a/python/pylibraft/pylibraft/common/outputs.py b/python/pylibraft/pylibraft/common/outputs.py new file mode 100644 index 0000000000..e5b08e1798 --- /dev/null +++ 
b/python/pylibraft/pylibraft/common/outputs.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +import warnings + +import pylibraft.config + + +def import_warn_(lib): + warnings.warn( + "%s is not available and output cannot be converted. " + "Returning original output instead." % lib + ) + + +def convert_to_torch(device_ndarray): + try: + import torch + + return torch.as_tensor(device_ndarray, device="cuda") + except ImportError: + import_warn_("PyTorch") + return device_ndarray + + +def convert_to_cupy(device_ndarray): + try: + import cupy + + return cupy.asarray(device_ndarray) + except ImportError: + import_warn_("CuPy") + return device_ndarray + + +def no_conversion(device_ndarray): + return device_ndarray + + +def convert_to_cai_type(device_ndarray): + output_as_ = pylibraft.config.output_as_ + if callable(output_as_): + return output_as_(device_ndarray) + elif output_as_ == "raft": + return device_ndarray + elif output_as_ == "torch": + return convert_to_torch(device_ndarray) + elif output_as_ == "cupy": + return convert_to_cupy(device_ndarray) + else: + raise ValueError("No valid type conversion found for %s" % output_as_) + + +def conv(ret): + for i in ret: + if isinstance(i, pylibraft.common.device_ndarray): + yield convert_to_cai_type(i) + else: + yield i + + +def auto_convert_output(f): + """Decorator to automatically convert an output device_ndarray + (or list or tuple of
device_ndarray) into the configured + `__cuda_array_interface__` compliant type. + """ + + @functools.wraps(f) + def wrapper(*args, **kwargs): + ret_value = f(*args, **kwargs) + if isinstance(ret_value, pylibraft.common.device_ndarray): + return convert_to_cai_type(ret_value) + elif isinstance(ret_value, tuple): + return tuple(conv(ret_value)) + elif isinstance(ret_value, list): + return list(conv(ret_value)) + else: + return ret_value + + return wrapper diff --git a/python/pylibraft/pylibraft/config.py b/python/pylibraft/pylibraft/config.py new file mode 100644 index 0000000000..c173bca2bd --- /dev/null +++ b/python/pylibraft/pylibraft/config.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +SUPPORTED_OUTPUT_TYPES = ["torch", "cupy", "raft"] + +output_as_ = "raft" # By default, return device_ndarray from functions + + +def set_output_as(output): + """ + Set output format for RAFT functions. + + Calling this function will change the output type of RAFT functions. + By default RAFT returns a `pylibraft.common.device_ndarray` for arrays + on GPU memory. Calling `set_output_as` allows you to have RAFT return + arrays as cupy arrays or pytorch tensors instead. You can also have + RAFT convert the output to other frameworks by passing a callable to + do the conversion here. + + Notes + ----- + Returning arrays in cupy or torch format requires you to install + cupy or torch. 
+ + Parameters + ---------- + output : { "raft", "cupy", "torch" } or callable + The output format to convert to. Can either be a str describing the + framework to convert to, or a callable that accepts a + device_ndarray and returns the converted type. + """ + if output not in SUPPORTED_OUTPUT_TYPES and not callable(output): + raise ValueError("Unsupported output option %s" % (output,)) + global output_as_ + output_as_ = output diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx index a21fe46fa3..ce8e656822 100644 --- a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx +++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx @@ -26,7 +26,12 @@ from libcpp cimport bool from .distance_type cimport DistanceType -from pylibraft.common import Handle, cai_wrapper, device_ndarray +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t @@ -57,6 +62,7 @@ cdef extern from "raft_runtime/distance/fused_l2_nn.hpp" \ @auto_sync_handle +@auto_convert_output def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None): """ Compute the 1-nearest neighbors between X and Y using the L2 distance diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 6f7a135951..2ed2b8ed57 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -31,7 +31,7 @@ from pylibraft.common.handle import auto_sync_handle from pylibraft.common.handle cimport handle_t -from pylibraft.common import cai_wrapper, device_ndarray +from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \ @@ -89,6 +89,7 @@ SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2",
"inner_product", @auto_sync_handle +@auto_convert_output def distance(X, Y, out=None, metric="euclidean", p=2.0, handle=None): """ Compute pairwise distances between X and Y diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index fdc8d1755c..6ad9b753b3 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -33,7 +33,12 @@ from libcpp cimport bool, nullptr from pylibraft.distance.distance_type cimport DistanceType -from pylibraft.common import Handle, cai_wrapper, device_ndarray +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) from pylibraft.common.interruptible import cuda_interruptible from pylibraft.common.handle cimport handle_t @@ -302,6 +307,7 @@ cdef class Index: @auto_sync_handle +@auto_convert_output def build(IndexParams index_params, dataset, handle=None): """ Builds an IVF-PQ index that can be later used for nearest neighbor search. @@ -401,6 +407,7 @@ def build(IndexParams index_params, dataset, handle=None): @auto_sync_handle +@auto_convert_output def extend(Index index, new_vectors, new_indices, handle=None): """ Extend an existing index with new vectors. 
@@ -565,6 +572,7 @@ cdef class SearchParams: @auto_sync_handle +@auto_convert_output def search(SearchParams search_params, Index index, queries, diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index ca328c1cd5..37ef69e7b5 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -33,7 +33,12 @@ from libcpp cimport bool, nullptr from pylibraft.distance.distance_type cimport DistanceType -from pylibraft.common import Handle, cai_wrapper, device_ndarray +from pylibraft.common import ( + Handle, + auto_convert_output, + cai_wrapper, + device_ndarray, +) from pylibraft.common.handle cimport handle_t @@ -208,6 +213,7 @@ cdef host_matrix_view[int8_t, uint64_t, row_major] \ @auto_sync_handle +@auto_convert_output def refine(dataset, queries, candidates, k=None, indices=None, distances=None, metric="l2_expanded", handle=None): """ diff --git a/python/pylibraft/pylibraft/test/test_config.py b/python/pylibraft/pylibraft/test/test_config.py new file mode 100644 index 0000000000..27a697d388 --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_config.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import pytest + +try: + import cupy +except ImportError: + pytest.skip(reason="cupy not installed.", allow_module_level=True) + +import pylibraft.config +from pylibraft.common import auto_convert_output, device_ndarray + + +@auto_convert_output +def gen_cai(m, n, t=None): + if t is None: + return device_ndarray.empty((m, n)) + elif t == tuple: + return device_ndarray.empty((m, n)), device_ndarray.empty((m, n)) + elif t == list: + return [device_ndarray.empty((m, n)), device_ndarray.empty((m, n))] + + +@pytest.mark.parametrize( + "out_type", + [ + ["cupy", cupy.ndarray], + ["raft", pylibraft.common.device_ndarray], + [lambda arr: arr.copy_to_host(), np.ndarray], + ], +) +@pytest.mark.parametrize("gen_t", [None, tuple, list]) +def test_auto_convert_output(out_type, gen_t): + + conf, t = out_type + pylibraft.config.set_output_as(conf) + + output = gen_cai(1, 5, gen_t) + + if not isinstance(output, (list, tuple)): + assert isinstance(output, t) + + else: + for o in output: + assert isinstance(o, t) + + # Make sure we set the config back to default + pylibraft.config.set_output_as("raft")