diff --git a/BUILD.md b/BUILD.md
index c94bb24204..d38db90249 100644
--- a/BUILD.md
+++ b/BUILD.md
@@ -101,7 +101,7 @@ For example, to run the distance tests:
It can take sometime to compile all of the tests. You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`:
```bash
-./build.sh libraft tests --limit-tests=SPATIAL_TEST;DISTANCE_TEST;MATRIX_TEST
+./build.sh libraft tests --limit-tests=NEIGHBORS_TEST;DISTANCE_TEST;MATRIX_TEST
```
### Benchmarks
@@ -111,10 +111,10 @@ The benchmarks are broken apart by algorithm category, so you will find several
./build.sh libraft bench
```
-It can take sometime to compile all of the tests. You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`:
+It can take sometime to compile all of the benchmarks. You can build individual benchmarks by providing a semicolon-separated list to the `--limit-bench` option in `build.sh`:
```bash
-./build.sh libraft bench --limit-bench=SPATIAL_BENCH;DISTANCE_BENCH;LINALG_BENCH
+./build.sh libraft bench --limit-bench=NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH
```
### C++ Using Cmake
diff --git a/README.md b/README.md
index 2c0231f37e..cc32e4d404 100755
--- a/README.md
+++ b/README.md
@@ -12,19 +12,19 @@ While not exhaustive, the following general categories help summarize the accele
| Category | Examples |
| --- | --- |
| **Data Formats** | sparse & dense, conversions, data generation |
-| **Dense Linear Algebra** | matrix arithmetic, norms, factorization, least squares, svd & eigenvalue problems |
+| **Dense Operations** | linear algebra, matrix and vector operations, slicing, norms, factorization, least squares, svd & eigenvalue problems |
+| **Sparse Operations** | linear algebra, eigenvalue problems, slicing, symmetrization, components & labeling |
| **Spatial** | pairwise distances, nearest neighbors, neighborhood graph construction |
-| **Sparse Operations** | linear algebra, eigenvalue problems, slicing, symmetrization, labeling |
| **Basic Clustering** | spectral clustering, hierarchical clustering, k-means |
| **Solvers** | combinatorial optimization, iterative solvers |
| **Statistics** | sampling, moments and summary statistics, metrics |
-| **Distributed Tools** | multi-node multi-gpu infrastructure |
+| **Tools & Utilities** | common utilities for developing CUDA applications, multi-node multi-gpu infrastructure |
RAFT provides a header-only C++ library and pre-compiled shared libraries that can 1) speed up compile times and 2) enable the APIs to be used without CUDA-enabled compilers.
-RAFT also provides 2 Python libraries:
-- `pylibraft` - low-level Python wrappers around RAFT algorithms and primitives.
-- `raft-dask` - reusable infrastructure for building analytics, including tools for building both single-GPU and multi-node multi-GPU algorithms.
+In addition to the C++ library, RAFT also provides 2 Python libraries:
+- `pylibraft` - lightweight low-level Python wrappers around RAFT algorithms and primitives.
+- `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask.
## Getting started
@@ -78,9 +78,9 @@ raft::distance::pairwise_distance(handle, input.view(), input.view(), output.vie
### Python Example
-The `pylibraft` package contains a Python API for RAFT algorithms and primitives. The package is currently limited to pairwise distances, and we will continue adding more.
+The `pylibraft` package contains a Python API for RAFT algorithms and primitives. `pylibraft` integrates nicely into other libraries by being very lightweight with minimal dependencies and accepting any object that supports the `__cuda_array_interface__`, such as [CuPy's ndarray](https://docs.cupy.dev/en/stable/user_guide/interoperability.html#rmm). The package is currently limited to pairwise distances and RMAT graph generation, but we will continue adding more in future releases.
-The example below demonstrates computing the pairwise Euclidean distances between cupy arrays. `pylibraft` is a low-level API that prioritizes efficiency and simplicity over being pythonic, which is shown here by pre-allocating the output memory before invoking the `pairwise_distance` function.
+The example below demonstrates computing the pairwise Euclidean distances between CuPy arrays. `pylibraft` is a low-level API that prioritizes efficiency and simplicity over being pythonic, which is shown here by pre-allocating the output memory before invoking the `pairwise_distance` function. Note that CuPy is not a required dependency for `pylibraft`.
```python
import cupy as cp
@@ -107,7 +107,7 @@ The easiest way to install RAFT is through conda and several packages are provid
- `libraft-headers` RAFT headers
- `libraft-nn` (optional) contains shared libraries for the nearest neighbors primitives.
- `libraft-distance` (optional) contains shared libraries for distance primitives.
-- `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives
+- `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives.
- `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters.
Use the following command to install all of the RAFT packages with conda (replace `rapidsai` with `rapidsai-nightly` to install more up-to-date but less stable nightly packages). `mamba` is preferred over the `conda` command.
@@ -198,7 +198,25 @@ The folder structure mirrors other RAPIDS repos, with the following folders:
- `bench`: Benchmarks source code
- `cmake`: Cmake modules and templates
- `doxygen`: Doxygen configuration
- - `include`: The C++ API headers are fully-contained here
+ - `include`: The C++ API headers are fully-contained here (deprecated directories are excluded from the listing below)
+ - `cluster`: Basic clustering primitives and algorithms.
+ - `comms`: A multi-node multi-GPU communications abstraction layer for NCCL+UCX and MPI+NCCL, which can be deployed in Dask clusters using the `raft-dask` Python package.
+ - `core`: Core API headers which require minimal dependencies aside from RMM and Cudatoolkit. These are safe to expose on public APIs and do not require `nvcc` to build. This is the same for any headers in RAFT which have the suffix `*_types.hpp`.
+ - `distance`: Distance primitives
+ - `linalg`: Dense linear algebra
+ - `matrix`: Dense matrix operations
+ - `neighbors`: Nearest neighbors and knn graph construction
+ - `random`: Random number generation, sampling, and data generation primitives
+ - `solver`: Iterative and combinatorial solvers for optimization and approximation
+ - `sparse`: Sparse matrix operations
+ - `convert`: Sparse conversion functions
+ - `distance`: Sparse distance computations
+ - `linalg`: Sparse linear algebra
+ - `neighbors`: Sparse nearest neighbors and knn graph construction
+ - `op`: Various sparse operations such as slicing and filtering (Note: this will soon be renamed to `sparse/matrix`)
+ - `solver`: Sparse solvers for optimization and approximation
+ - `stats`: Moments, summary statistics, model performance measures
+ - `util`: Various reusable tools and utilities for accelerated algorithm development
- `scripts`: Helpful scripts for development
- `src`: Compiled APIs and template specializations for the shared libraries
- `test`: Googletests source code
diff --git a/build.sh b/build.sh
index d1dd8bdde1..9548fbec44 100755
--- a/build.sh
+++ b/build.sh
@@ -40,8 +40,8 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=
-#include
-#include
+#include
+#include
#include
#if defined RAFT_DISTANCE_COMPILED
@@ -143,16 +143,16 @@ template
struct ivf_flat_knn {
using dist_t = float;
- std::optional> index;
- raft::spatial::knn::ivf_flat::index_params index_params;
- raft::spatial::knn::ivf_flat::search_params search_params;
+ std::optional> index;
+ raft::neighbors::ivf_flat::index_params index_params;
+ raft::neighbors::ivf_flat::search_params search_params;
params ps;
ivf_flat_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
- index.emplace(raft::spatial::knn::ivf_flat::build(
+ index.emplace(raft::neighbors::ivf_flat::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}
@@ -162,7 +162,7 @@ struct ivf_flat_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;
- raft::spatial::knn::ivf_flat::search(
+ raft::neighbors::ivf_flat::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
}
};
@@ -171,16 +171,16 @@ template
struct ivf_pq_knn {
using dist_t = float;
- std::optional> index;
- raft::spatial::knn::ivf_pq::index_params index_params;
- raft::spatial::knn::ivf_pq::search_params search_params;
+ std::optional> index;
+ raft::neighbors::ivf_pq::index_params index_params;
+ raft::neighbors::ivf_pq::search_params search_params;
params ps;
ivf_pq_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps)
{
index_params.n_lists = 4096;
index_params.metric = raft::distance::DistanceType::L2Expanded;
- index.emplace(raft::spatial::knn::ivf_pq::build(
+ index.emplace(raft::neighbors::ivf_pq::build(
handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims)));
}
@@ -190,7 +190,7 @@ struct ivf_pq_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;
- raft::spatial::knn::ivf_pq::search(
+ raft::neighbors::ivf_pq::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
}
};
diff --git a/cpp/bench/spatial/knn/brute_force_float_int64_t.cu b/cpp/bench/neighbors/knn/brute_force_float_int64_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/brute_force_float_int64_t.cu
rename to cpp/bench/neighbors/knn/brute_force_float_int64_t.cu
diff --git a/cpp/bench/spatial/knn/brute_force_float_uint32_t.cu b/cpp/bench/neighbors/knn/brute_force_float_uint32_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/brute_force_float_uint32_t.cu
rename to cpp/bench/neighbors/knn/brute_force_float_uint32_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_flat_float_int64_t.cu b/cpp/bench/neighbors/knn/ivf_flat_float_int64_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_flat_float_int64_t.cu
rename to cpp/bench/neighbors/knn/ivf_flat_float_int64_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_flat_float_uint32_t.cu b/cpp/bench/neighbors/knn/ivf_flat_float_uint32_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_flat_float_uint32_t.cu
rename to cpp/bench/neighbors/knn/ivf_flat_float_uint32_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_flat_int8_t_int64_t.cu b/cpp/bench/neighbors/knn/ivf_flat_int8_t_int64_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_flat_int8_t_int64_t.cu
rename to cpp/bench/neighbors/knn/ivf_flat_int8_t_int64_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_flat_uint8_t_uint32_t.cu b/cpp/bench/neighbors/knn/ivf_flat_uint8_t_uint32_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_flat_uint8_t_uint32_t.cu
rename to cpp/bench/neighbors/knn/ivf_flat_uint8_t_uint32_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_pq_float_int64_t.cu b/cpp/bench/neighbors/knn/ivf_pq_float_int64_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_pq_float_int64_t.cu
rename to cpp/bench/neighbors/knn/ivf_pq_float_int64_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_pq_float_uint32_t.cu b/cpp/bench/neighbors/knn/ivf_pq_float_uint32_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_pq_float_uint32_t.cu
rename to cpp/bench/neighbors/knn/ivf_pq_float_uint32_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_pq_int8_t_int64_t.cu b/cpp/bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_pq_int8_t_int64_t.cu
rename to cpp/bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu
diff --git a/cpp/bench/spatial/knn/ivf_pq_uint8_t_uint32_t.cu b/cpp/bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu
similarity index 100%
rename from cpp/bench/spatial/knn/ivf_pq_uint8_t_uint32_t.cu
rename to cpp/bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu
diff --git a/cpp/bench/spatial/selection.cu b/cpp/bench/neighbors/selection.cu
similarity index 100%
rename from cpp/bench/spatial/selection.cu
rename to cpp/bench/neighbors/selection.cu
diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh
index da8adf783d..a07045f0d2 100644
--- a/cpp/include/raft/cluster/detail/connectivities.cuh
+++ b/cpp/include/raft/cluster/detail/connectivities.cuh
@@ -27,7 +27,7 @@
#include
#include
#include
-#include
+#include
#include
#include
@@ -73,7 +73,7 @@ struct distance_graph_impl knn_graph_coo(stream);
- raft::sparse::spatial::knn_graph(handle, X, m, n, metric, knn_graph_coo, c);
+ raft::sparse::neighbors::knn_graph(handle, X, m, n, metric, knn_graph_coo, c);
indices.resize(knn_graph_coo.nnz, stream);
data.resize(knn_graph_coo.nnz, stream);
diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh
index 67935d4623..8143d21641 100644
--- a/cpp/include/raft/cluster/detail/mst.cuh
+++ b/cpp/include/raft/cluster/detail/mst.cuh
@@ -19,9 +19,9 @@
#include
#include
+#include
#include
#include
-#include
#include
#include
@@ -80,7 +80,7 @@ void connect_knn_graph(
raft::sparse::COO connected_edges(stream);
- raft::sparse::spatial::connect_components(
+ raft::sparse::neighbors::connect_components(
handle, connected_edges, X, color, m, n, reduction_op);
rmm::device_uvector indptr2(m + 1, stream);
@@ -153,14 +153,14 @@ void build_sorted_mst(
handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false, true);
int iters = 1;
- int n_components = raft::sparse::spatial::get_n_components(color, m, stream);
+ int n_components = raft::sparse::neighbors::get_n_components(color, m, stream);
while (n_components > 1 && iters < max_iter) {
connect_knn_graph(handle, X, mst_coo, m, n, color, reduction_op);
iters++;
- n_components = raft::sparse::spatial::get_n_components(color, m, stream);
+ n_components = raft::sparse::neighbors::get_n_components(color, m, stream);
}
/**
diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh
index 9eee21b09c..d12db85e1b 100644
--- a/cpp/include/raft/cluster/detail/single_linkage.cuh
+++ b/cpp/include/raft/cluster/detail/single_linkage.cuh
@@ -80,7 +80,7 @@ void single_linkage(const raft::handle_t& handle,
* 2. Construct MST, sorted by weights
*/
rmm::device_uvector color(m, stream);
- raft::sparse::spatial::FixConnectivitiesRedOp op(color.data(), m);
+ raft::sparse::neighbors::FixConnectivitiesRedOp op(color.data(), m);
detail::build_sorted_mst(handle,
X,
indptr.data(),
diff --git a/cpp/include/raft/neighbors/ann_types.hpp b/cpp/include/raft/neighbors/ann_types.hpp
new file mode 100644
index 0000000000..5c6fd52be9
--- /dev/null
+++ b/cpp/include/raft/neighbors/ann_types.hpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+
+namespace raft::neighbors::ann {
+
+/** The base for approximate KNN index structures. */
+struct index {
+};
+
+/** The base for KNN index parameters. */
+struct index_params {
+ /** Distance type. */
+ raft::distance::DistanceType metric = distance::DistanceType::L2Expanded;
+ /** The argument used by some distance metrics. */
+ float metric_arg = 2.0f;
+ /**
+ * Whether to add the dataset content to the index, i.e.:
+ *
+ * - `true` means the index is filled with the dataset vectors and ready to search after calling
+ * `build`.
+ * - `false` means `build` only trains the underlying model (e.g. quantizer or clustering), but
+ * the index is left empty; you'd need to call `extend` on the index afterwards to populate it.
+ */
+ bool add_data_on_build = true;
+};
+
+struct search_params {
+};
+
+}; // namespace raft::neighbors::ann
diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover.cuh
new file mode 100644
index 0000000000..780a9cfce2
--- /dev/null
+++ b/cpp/include/raft/neighbors/ball_cover.cuh
@@ -0,0 +1,314 @@
+/*
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef __BALL_COVER_H
+#define __BALL_COVER_H
+
+#pragma once
+
+#include
+
+#include "ball_cover_types.hpp"
+#include
+#include
+#include
+#include
+
+namespace raft::neighbors::ball_cover {
+
+/**
+ * Builds and populates a previously unbuilt BallCoverIndex
+ * @tparam idx_t knn index type
+ * @tparam value_t knn value type
+ * @tparam int_t integral type for knn params
+ * @tparam matrix_idx_t matrix indexing type
+ * @param[in] handle library resource management handle
+ * @param[inout] index an empty (and not previous built) instance of BallCoverIndex
+ */
+template
+void build_index(const raft::handle_t& handle,
+ BallCoverIndex& index)
+{
+ ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
+ if (index.metric == raft::distance::DistanceType::Haversine) {
+ raft::spatial::knn::detail::rbc_build_index(
+ handle, index, spatial::knn::detail::HaversineFunc());
+ } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
+ index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) {
+ raft::spatial::knn::detail::rbc_build_index(
+ handle, index, spatial::knn::detail::EuclideanFunc());
+ } else {
+ RAFT_FAIL("Metric not support");
+ }
+
+ index.set_index_trained();
+}
+
+/**
+ * Performs a faster exact knn in metric spaces using the triangle
+ * inequality with a number of landmark points to reduce the
+ * number of distance computations from O(n^2) to O(sqrt(n)). This
+ * performs an all neighbors knn, which can reuse memory when
+ * the index and query are the same array. This function will
+ * build the index and assumes rbc_build_index() has not already
+ * been called.
+ * @tparam idx_t knn index type
+ * @tparam value_t knn distance type
+ * @tparam int_t type for integers, such as number of rows/cols
+ * @param[in] handle raft handle for resource management
+ * @param[inout] index ball cover index which has not yet been built
+ * @param[in] k number of nearest neighbors to find
+ * @param[in] perform_post_filtering if this is false, only the closest k landmarks
+ * are considered (which will return approximate
+ * results).
+ * @param[out] inds output knn indices
+ * @param[out] dists output knn distances
+ * @param[in] weight a weight for overlap between the closest landmark and
+ * the radius of other landmarks when pruning distances.
+ * Setting this value below 1 can effectively turn off
+ * computing distances against many other balls, enabling
+ * approximate nearest neighbors. Recall can be adjusted
+ * based on how many relevant balls are ignored. Note that
+ * many datasets can still have great recall even by only
+ * looking in the closest landmark.
+ */
+template
+void all_knn_query(const raft::handle_t& handle,
+ BallCoverIndex& index,
+ int_t k,
+ idx_t* inds,
+ value_t* dists,
+ bool perform_post_filtering = true,
+ float weight = 1.0)
+{
+ ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
+ if (index.metric == raft::distance::DistanceType::Haversine) {
+ raft::spatial::knn::detail::rbc_all_knn_query(
+ handle,
+ index,
+ k,
+ inds,
+ dists,
+ spatial::knn::detail::HaversineFunc(),
+ perform_post_filtering,
+ weight);
+ } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
+ index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) {
+ raft::spatial::knn::detail::rbc_all_knn_query(
+ handle,
+ index,
+ k,
+ inds,
+ dists,
+ spatial::knn::detail::EuclideanFunc(),
+ perform_post_filtering,
+ weight);
+ } else {
+ RAFT_FAIL("Metric not supported");
+ }
+
+ index.set_index_trained();
+}
+
+/**
+ * Performs a faster exact knn in metric spaces using the triangle
+ * inequality with a number of landmark points to reduce the
+ * number of distance computations from O(n^2) to O(sqrt(n)). This
+ * performs an all neighbors knn, which can reuse memory when
+ * the index and query are the same array. This function will
+ * build the index and assumes rbc_build_index() has not already
+ * been called.
+ * @tparam idx_t knn index type
+ * @tparam value_t knn distance type
+ * @tparam int_t type for integers, such as number of rows/cols
+ * @tparam matrix_idx_t matrix indexing type
+ * @param[in] handle raft handle for resource management
+ * @param[in] index ball cover index which has not yet been built
+ * @param[out] inds output knn indices
+ * @param[out] dists output knn distances
+ * @param[in] k number of nearest neighbors to find
+ * @param[in] perform_post_filtering if this is false, only the closest k landmarks
+ * are considered (which will return approximate
+ * results).
+ * @param[in] weight a weight for overlap between the closest landmark and
+ * the radius of other landmarks when pruning distances.
+ * Setting this value below 1 can effectively turn off
+ * computing distances against many other balls, enabling
+ * approximate nearest neighbors. Recall can be adjusted
+ * based on how many relevant balls are ignored. Note that
+ * many datasets can still have great recall even by only
+ * looking in the closest landmark.
+ */
+template
+void all_knn_query(const raft::handle_t& handle,
+ BallCoverIndex& index,
+ raft::device_matrix_view inds,
+ raft::device_matrix_view dists,
+ int_t k,
+ bool perform_post_filtering = true,
+ float weight = 1.0)
+{
+ RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
+ RAFT_EXPECTS(k <= index.m,
+ "k must be less than or equal to the number of data points in the index");
+ RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k),
+ "Number of columns in output indices and distances matrices must be equal to k");
+
+ RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0),
+ "Number of rows in output indices and distances matrices must equal number of rows "
+ "in index matrix.");
+
+ all_knn_query(
+ handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight);
+}
+
+/**
+ * Performs a faster exact knn in metric spaces using the triangle
+ * inequality with a number of landmark points to reduce the
+ * number of distance computations from O(n^2) to O(sqrt(n)). This
+ * function does not build the index and assumes rbc_build_index() has
+ * already been called. Use this function when the index and
+ * query arrays are different, otherwise use rbc_all_knn_query().
+ * @tparam idx_t index type
+ * @tparam value_t distances type
+ * @tparam int_t integer type for size info
+ * @param[in] handle raft handle for resource management
+ * @param[inout] index ball cover index which has not yet been built
+ * @param[in] k number of nearest neighbors to find
+ * @param[in] query the
+ * @param[in] perform_post_filtering if this is false, only the closest k landmarks
+ * are considered (which will return approximate
+ * results).
+ * @param[out] inds output knn indices
+ * @param[out] dists output knn distances
+ * @param[in] weight a weight for overlap between the closest landmark and
+ * the radius of other landmarks when pruning distances.
+ * Setting this value below 1 can effectively turn off
+ * computing distances against many other balls, enabling
+ * approximate nearest neighbors. Recall can be adjusted
+ * based on how many relevant balls are ignored. Note that
+ * many datasets can still have great recall even by only
+ * looking in the closest landmark.
+ * @param[in] n_query_pts number of query points
+ */
+template
+void knn_query(const raft::handle_t& handle,
+ const BallCoverIndex& index,
+ int_t k,
+ const value_t* query,
+ int_t n_query_pts,
+ idx_t* inds,
+ value_t* dists,
+ bool perform_post_filtering = true,
+ float weight = 1.0)
+{
+ ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
+ if (index.metric == raft::distance::DistanceType::Haversine) {
+ raft::spatial::knn::detail::rbc_knn_query(handle,
+ index,
+ k,
+ query,
+ n_query_pts,
+ inds,
+ dists,
+ spatial::knn::detail::HaversineFunc(),
+ perform_post_filtering,
+ weight);
+ } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
+ index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) {
+ raft::spatial::knn::detail::rbc_knn_query(handle,
+ index,
+ k,
+ query,
+ n_query_pts,
+ inds,
+ dists,
+ spatial::knn::detail::EuclideanFunc(),
+ perform_post_filtering,
+ weight);
+ } else {
+ RAFT_FAIL("Metric not supported");
+ }
+}
+
+/**
+ * Performs a faster exact knn in metric spaces using the triangle
+ * inequality with a number of landmark points to reduce the
+ * number of distance computations from O(n^2) to O(sqrt(n)). This
+ * function does not build the index and assumes rbc_build_index() has
+ * already been called. Use this function when the index and
+ * query arrays are different, otherwise use rbc_all_knn_query().
+ * @tparam idx_t index type
+ * @tparam value_t distances type
+ * @tparam int_t integer type for size info
+ * @tparam matrix_idx_t
+ * @param[in] handle raft handle for resource management
+ * @param[in] index ball cover index which has not yet been built
+ * @param[in] query device matrix containing query data points
+ * @param[out] inds output knn indices
+ * @param[out] dists output knn distances
+ * @param[in] k number of nearest neighbors to find
+ * @param[in] perform_post_filtering if this is false, only the closest k landmarks
+ * are considered (which will return approximate
+ * results).
+ * @param[in] weight a weight for overlap between the closest landmark and
+ * the radius of other landmarks when pruning distances.
+ * Setting this value below 1 can effectively turn off
+ * computing distances against many other balls, enabling
+ * approximate nearest neighbors. Recall can be adjusted
+ * based on how many relevant balls are ignored. Note that
+ * many datasets can still have great recall even by only
+ * looking in the closest landmark.
+ */
+template
+void knn_query(const raft::handle_t& handle,
+ const BallCoverIndex& index,
+ raft::device_matrix_view query,
+ raft::device_matrix_view inds,
+ raft::device_matrix_view dists,
+ int_t k,
+ bool perform_post_filtering = true,
+ float weight = 1.0)
+{
+ RAFT_EXPECTS(k <= index.m,
+ "k must be less than or equal to the number of data points in the index");
+ RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast(k),
+ "Number of columns in output indices and distances matrices must be equal to k");
+
+ RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0),
+ "Number of rows in output indices and distances matrices must equal number of rows "
+ "in search matrix.");
+
+ RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1),
+ "Number of columns in query and index matrices must match.");
+
+ knn_query(handle,
+ index,
+ k,
+ query.data_handle(),
+ query.extent(0),
+ inds.data_handle(),
+ dists.data_handle(),
+ perform_post_filtering,
+ weight);
+}
+
+// TODO: implement functions for:
+// 4. rbc_eps_neigh() - given a populated index, perform query against different query array
+// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data
+
+} // namespace raft::neighbors::ball_cover
+
+#endif
diff --git a/cpp/include/raft/neighbors/ball_cover_types.hpp b/cpp/include/raft/neighbors/ball_cover_types.hpp
new file mode 100644
index 0000000000..f6e49ab5c4
--- /dev/null
+++ b/cpp/include/raft/neighbors/ball_cover_types.hpp
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::neighbors::ball_cover {
+
+/**
+ * Stores raw index data points, sampled landmarks, the 1-nns of index points
+ * to their closest landmarks, and the ball radii of each landmark. This
+ * class is intended to be constructed once and reused across subsequent
+ * queries.
+ * @tparam value_idx
+ * @tparam value_t
+ * @tparam value_int
+ */
+template
+class BallCoverIndex {
+ public:
+ explicit BallCoverIndex(const raft::handle_t& handle_,
+ const value_t* X_,
+ value_int m_,
+ value_int n_,
+ raft::distance::DistanceType metric_)
+ : handle(handle_),
+ X(raft::make_device_matrix_view(X_, m_, n_)),
+ m(m_),
+ n(n_),
+ metric(metric_),
+ /**
+ * the sqrt() here makes the sqrt(m)^2 a linear-time lower bound
+ *
+ * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m)
+ */
+ n_landmarks(sqrt(m_)),
+ R_indptr(raft::make_device_vector(handle, sqrt(m_) + 1)),
+ R_1nn_cols(raft::make_device_vector(handle, m_)),
+ R_1nn_dists(raft::make_device_vector(handle, m_)),
+ R_closest_landmark_dists(raft::make_device_vector(handle, m_)),
+ R(raft::make_device_matrix(handle, sqrt(m_), n_)),
+ R_radius(raft::make_device_vector(handle, sqrt(m_))),
+ index_trained(false)
+ {
+ }
+
+ explicit BallCoverIndex(const raft::handle_t& handle_,
+ raft::device_matrix_view X_,
+ raft::distance::DistanceType metric_)
+ : handle(handle_),
+ X(X_),
+ m(X_.extent(0)),
+ n(X_.extent(1)),
+ metric(metric_),
+ /**
+ * the sqrt() here makes the sqrt(m)^2 a linear-time lower bound
+ *
+ * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m)
+ */
+ n_landmarks(sqrt(X_.extent(0))),
+ R_indptr(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1)),
+ R_1nn_cols(raft::make_device_vector(handle, X_.extent(0))),
+ R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))),
+ R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))),
+ R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))),
+ R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))),
+ index_trained(false)
+ {
+ }
+
+ auto get_R_indptr() const -> raft::device_vector_view
+ {
+ return R_indptr.view();
+ }
+ auto get_R_1nn_cols() const -> raft::device_vector_view
+ {
+ return R_1nn_cols.view();
+ }
+ auto get_R_1nn_dists() const -> raft::device_vector_view
+ {
+ return R_1nn_dists.view();
+ }
+ auto get_R_radius() const -> raft::device_vector_view
+ {
+ return R_radius.view();
+ }
+ auto get_R() const -> raft::device_matrix_view
+ {
+ return R.view();
+ }
+ auto get_R_closest_landmark_dists() const -> raft::device_vector_view
+ {
+ return R_closest_landmark_dists.view();
+ }
+
+ raft::device_vector_view get_R_indptr() { return R_indptr.view(); }
+ raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); }
+ raft::device_vector_view get_R_1nn_dists() { return R_1nn_dists.view(); }
+ raft::device_vector_view get_R_radius() { return R_radius.view(); }
+ raft::device_matrix_view get_R() { return R.view(); }
+ raft::device_vector_view get_R_closest_landmark_dists()
+ {
+ return R_closest_landmark_dists.view();
+ }
+ raft::device_matrix_view get_X() const { return X; }
+
+ raft::distance::DistanceType get_metric() const { return metric; }
+
+ value_int get_n_landmarks() const { return n_landmarks; }
+ bool is_index_trained() const { return index_trained; };
+
+ // This should only be set by internal functions
+ void set_index_trained() { index_trained = true; }
+
+ const raft::handle_t& handle;
+
+ value_int m;
+ value_int n;
+ value_int n_landmarks;
+
+ raft::device_matrix_view X;
+
+ raft::distance::DistanceType metric;
+
+ private:
+ // CSR storing the neighborhoods for each data point
+ raft::device_vector R_indptr;
+ raft::device_vector R_1nn_cols;
+ raft::device_vector R_1nn_dists;
+ raft::device_vector R_closest_landmark_dists;
+
+ raft::device_vector R_radius;
+
+ raft::device_matrix R;
+
+ protected:
+ bool index_trained;
+};
+} // namespace raft::neighbors::ball_cover
diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh
similarity index 65%
rename from cpp/include/raft/spatial/knn/brute_force.cuh
rename to cpp/include/raft/neighbors/brute_force.cuh
index dda1e02eed..3641a38991 100644
--- a/cpp/include/raft/spatial/knn/brute_force.cuh
+++ b/cpp/include/raft/neighbors/brute_force.cuh
@@ -16,11 +16,11 @@
#pragma once
-#include "detail/knn_brute_force_faiss.cuh"
-#include "detail/selection_faiss.cuh"
#include
+#include
+#include
-namespace raft::spatial::knn {
+namespace raft::neighbors::brute_force {
/**
* @brief Performs a k-select across row partitioned index/distance
@@ -63,15 +63,15 @@ inline void knn_merge_parts(
"Number of columns in output indices and distances matrices must be equal to k");
auto n_parts = in_keys.extent(0) / n_samples;
- detail::knn_merge_parts(in_keys.data_handle(),
- in_values.data_handle(),
- out_keys.data_handle(),
- out_values.data_handle(),
- n_samples,
- n_parts,
- in_keys.extent(1),
- handle.get_stream(),
- translations.value_or(nullptr));
+ spatial::knn::detail::knn_merge_parts(in_keys.data_handle(),
+ in_values.data_handle(),
+ out_keys.data_handle(),
+ out_values.data_handle(),
+ n_samples,
+ n_parts,
+ in_keys.extent(1),
+ handle.get_stream(),
+ translations.value_or(nullptr));
}
/**
@@ -99,16 +99,15 @@ template
-void brute_force_knn(
- raft::handle_t const& handle,
- std::vector> index,
- raft::device_matrix_view search,
- raft::device_matrix_view indices,
- raft::device_matrix_view distances,
- value_int k,
- distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
- std::optional metric_arg = std::make_optional(2.0f),
- std::optional> translations = std::nullopt)
+void knn(raft::handle_t const& handle,
+ std::vector> index,
+ raft::device_matrix_view search,
+ raft::device_matrix_view indices,
+ raft::device_matrix_view distances,
+ value_int k,
+ distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
+ std::optional metric_arg = std::make_optional(2.0f),
+ std::optional> translations = std::nullopt)
{
RAFT_EXPECTS(index[0].extent(1) == search.extent(1),
"Number of dimensions for both index and search matrices must be equal");
@@ -132,21 +131,21 @@ void brute_force_knn(
std::vector* trans = translations.has_value() ? &(*translations) : nullptr;
- detail::brute_force_knn_impl(handle,
- inputs,
- sizes,
- static_cast(index[0].extent(1)),
- // TODO: This is unfortunate. Need to fix.
- const_cast(search.data_handle()),
- static_cast(search.extent(0)),
- indices.data_handle(),
- distances.data_handle(),
- k,
- rowMajorIndex,
- rowMajorQuery,
- trans,
- metric,
- metric_arg.value_or(2.0f));
+ raft::spatial::knn::detail::brute_force_knn_impl(handle,
+ inputs,
+ sizes,
+ static_cast(index[0].extent(1)),
+ // TODO: This is unfortunate. Need to fix.
+ const_cast(search.data_handle()),
+ static_cast(search.extent(0)),
+ indices.data_handle(),
+ distances.data_handle(),
+ k,
+ rowMajorIndex,
+ rowMajorQuery,
+ trans,
+ metric,
+ metric_arg.value_or(2.0f));
}
-} // namespace raft::spatial::knn
+} // namespace raft::neighbors::brute_force
diff --git a/cpp/include/raft/neighbors/epsilon_neighborhood.cuh b/cpp/include/raft/neighbors/epsilon_neighborhood.cuh
new file mode 100644
index 0000000000..b0e9b842ec
--- /dev/null
+++ b/cpp/include/raft/neighbors/epsilon_neighborhood.cuh
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __EPSILON_NEIGH_H
+#define __EPSILON_NEIGH_H
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace raft::neighbors::epsilon_neighborhood {
+
+/**
+ * @brief Computes epsilon neighborhood for the L2-Squared distance metric
+ *
+ * @tparam value_t IO and math type
+ * @tparam idx_t Index type
+ *
+ * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n]
+ * @param[out] vd vertex degree array [on device] [len = m + 1]
+ * `vd + m` stores the total number of edges in the adjacency
+ * matrix. Pass a nullptr if you don't need this info.
+ * @param[in] x first matrix [row-major] [on device] [dim = m x k]
+ * @param[in] y second matrix [row-major] [on device] [dim = n x k]
+ * @param[in] m number of rows in x
+ * @param[in] n number of rows in y
+ * @param[in] k number of columns in x and k
+ * @param[in] eps defines epsilon neighborhood radius (should be passed as
+ * squared as we compute L2-squared distance in this method)
+ * @param[in] stream cuda stream
+ */
+template
+void epsUnexpL2SqNeighborhood(bool* adj,
+ idx_t* vd,
+ const value_t* x,
+ const value_t* y,
+ idx_t m,
+ idx_t n,
+ idx_t k,
+ value_t eps,
+ cudaStream_t stream)
+{
+ spatial::knn::detail::epsUnexpL2SqNeighborhood(
+ adj, vd, x, y, m, n, k, eps, stream);
+}
+
+/**
+ * @brief Computes epsilon neighborhood for the L2-Squared distance metric
+ *
+ * @tparam value_t IO and math type
+ * @tparam idx_t Index type
+ * @tparam matrix_idx_t matrix indexing type
+ *
+ * @param[in] handle raft handle to manage library resources
+ * @param[in] x first matrix [row-major] [on device] [dim = m x k]
+ * @param[in] y second matrix [row-major] [on device] [dim = n x k]
+ * @param[out] adj adjacency matrix [row-major] [on device] [dim = m x n]
+ * @param[out] vd vertex degree array [on device] [len = m + 1]
+ * `vd + m` stores the total number of edges in the adjacency
+ * matrix. Pass a nullptr if you don't need this info.
+ * @param[in] eps defines epsilon neighborhood radius (should be passed as
+ * squared as we compute L2-squared distance in this method)
+ */
+template
+void eps_neighbors_l2sq(const raft::handle_t& handle,
+ raft::device_matrix_view x,
+ raft::device_matrix_view y,
+ raft::device_matrix_view adj,
+ raft::device_vector_view vd,
+ value_t eps)
+{
+ epsUnexpL2SqNeighborhood(adj.data_handle(),
+ vd.data_handle(),
+ x.data_handle(),
+ y.data_handle(),
+ x.extent(0),
+ y.extent(0),
+ x.extent(1),
+ eps,
+ handle.get_stream());
+}
+
+} // namespace raft::neighbors::epsilon_neighborhood
+
+#endif
\ No newline at end of file
diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh
new file mode 100644
index 0000000000..23ae6c42bf
--- /dev/null
+++ b/cpp/include/raft/neighbors/ivf_flat.cuh
@@ -0,0 +1,387 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "ivf_flat_types.hpp"
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+namespace raft::neighbors::ivf_flat {
+
+/**
+ * @brief Build the index from the dataset for efficient search.
+ *
+ * NB: Currently, the following distance metrics are supported:
+ * - L2Expanded
+ * - L2Unexpanded
+ * - InnerProduct
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::spatial::knn;
+ * // use default index parameters
+ * ivf_flat::index_params index_params;
+ * // create and fill the index from a [N, D] dataset
+ * auto index = ivf_flat::build(handle, index_params, dataset, N, D);
+ * // use default search parameters
+ * ivf_flat::search_params search_params;
+ * // search K nearest neighbours for each of the N queries
+ * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists);
+ * @endcode
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param[in] handle
+ * @param[in] params configure the index building
+ * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
+ * @param[in] n_rows the number of samples
+ * @param[in] dim the dimensionality of the data
+ *
+ * @return the constructed ivf-flat index
+ */
+template
+inline auto build(
+ const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim)
+ -> index
+{
+ return raft::spatial::knn::ivf_flat::detail::build(handle, params, dataset, n_rows, dim);
+}
+
+/**
+ * @brief Build the index from the dataset for efficient search.
+ *
+ * NB: Currently, the following distance metrics are supported:
+ * - L2Expanded
+ * - L2Unexpanded
+ * - InnerProduct
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::spatial::knn;
+ * // use default index parameters
+ * ivf_flat::index_params index_params;
+ * // create and fill the index from a [N, D] dataset
+ * auto index = ivf_flat::build(handle, index_params, dataset, N, D);
+ * // use default search parameters
+ * ivf_flat::search_params search_params;
+ * // search K nearest neighbours for each of the N queries
+ * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists);
+ * @endcode
+ *
+ * @tparam value_t data element type
+ * @tparam idx_t type of the indices in the source dataset
+ * @tparam int_t precision / type of integral arguments
+ * @tparam matrix_idx_t matrix indexing type
+ *
+ * @param[in] handle
+ * @param[in] params configure the index building
+ * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
+ *
+ * @return the constructed ivf-flat index
+ */
+template
+auto build_index(const handle_t& handle,
+ raft::device_matrix_view dataset,
+ const index_params& params) -> index
+{
+ return raft::spatial::knn::ivf_flat::detail::build(handle,
+ params,
+ dataset.data_handle(),
+ static_cast(dataset.extent(0)),
+ static_cast(dataset.extent(1)));
+}
+
+/**
+ * @brief Build a new index containing the data of the original plus new extra vectors.
+ *
+ * Implementation note:
+ * The new data is clustered according to existing kmeans clusters, then the cluster
+ * centers are adjusted to match the newly labeled data.
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::spatial::knn;
+ * ivf_flat::index_params index_params;
+ * index_params.add_data_on_build = false; // don't populate index on build
+ * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
+ * // train the index from a [N, D] dataset
+ * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D);
+ * // fill the index with the data
+ * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N);
+ * @endcode
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param[in] handle
+ * @param[in] orig_index original index
+ * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
+ * @param[in] new_indices a device pointer to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
+ * here to imply a continuous range `[0...n_rows)`.
+ * @param[in] n_rows number of rows in `new_vectors`
+ *
+ * @return the constructed extended ivf-flat index
+ */
+template
+inline auto extend(const handle_t& handle,
+ const index& orig_index,
+ const T* new_vectors,
+ const IdxT* new_indices,
+ IdxT n_rows) -> index
+{
+ return raft::spatial::knn::ivf_flat::detail::extend(
+ handle, orig_index, new_vectors, new_indices, n_rows);
+}
+
+/**
+ * @brief Build a new index containing the data of the original plus new extra vectors.
+ *
+ * Implementation note:
+ * The new data is clustered according to existing kmeans clusters, then the cluster
+ * centers are adjusted to match the newly labeled data.
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::spatial::knn;
+ * ivf_flat::index_params index_params;
+ * index_params.add_data_on_build = false; // don't populate index on build
+ * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
+ * // train the index from a [N, D] dataset
+ * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D);
+ * // fill the index with the data
+ * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N);
+ * @endcode
+ *
+ * @tparam value_t data element type
+ * @tparam idx_t type of the indices in the source dataset
+ * @tparam int_t precision / type of integral arguments
+ * @tparam matrix_idx_t matrix indexing type
+ *
+ * @param[in] handle
+ * @param[in] orig_index original index
+ * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
+ * @param[in] new_indices a device pointer to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
+ * here to imply a continuous range `[0...n_rows)`.
+ *
+ * @return the constructed extended ivf-flat index
+ */
+template
+auto extend(const handle_t& handle,
+ const index& orig_index,
+ raft::device_matrix_view new_vectors,
+ std::optional> new_indices = std::nullopt)
+ -> index
+{
+ return raft::spatial::knn::ivf_flat::detail::extend(
+ handle,
+ orig_index,
+ new_vectors.data_handle(),
+ new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
+ new_vectors.extent(0));
+}
+
+/**
+ * @brief Extend the index with the new data.
+ * *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param handle
+ * @param[inout] index
+ * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
+ * @param[in] new_indices a device pointer to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
+ * here to imply a continuous range `[0...n_rows)`.
+ * @param[in] n_rows the number of samples
+ */
+template
+inline void extend(const handle_t& handle,
+ index* index,
+ const T* new_vectors,
+ const IdxT* new_indices,
+ IdxT n_rows)
+{
+ *index = extend(handle, *index, new_vectors, new_indices, n_rows);
+}
+
+/**
+ * @brief Extend the index with the new data.
+ * *
+ * @tparam value_t data element type
+ * @tparam idx_t type of the indices in the source dataset
+ * @tparam int_t precision / type of integral arguments
+ * @tparam matrix_idx_t matrix indexing type
+ *
+ * @param[in] handle
+ * @param[inout] index
+ * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
+ * @param[in] new_indices a device pointer to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt`
+ * here to imply a continuous range `[0...n_rows)`.
+ */
+template
+void extend(const handle_t& handle,
+ index* index,
+ raft::device_matrix_view new_vectors,
+ std::optional> new_indices = std::nullopt)
+{
+ *index = extend(handle,
+ *index,
+ new_vectors.data_handle(),
+ new_indices.has_value() ? new_indices.value().data_handle() : nullptr,
+ static_cast(new_vectors.extent(0)));
+}
+
+/**
+ * @brief Search ANN using the constructed index.
+ *
+ * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
+ *
+ * Note, this function requires a temporary buffer to store intermediate results between cuda kernel
+ * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
+ * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
+ * eliminate entirely allocations happening within `search`:
+ * @code{.cpp}
+ * ...
+ * // Create a pooling memory resource with a pre-defined initial size.
+ * rmm::mr::pool_memory_resource mr(
+ * rmm::mr::get_current_device_resource(), 1024 * 1024);
+ * // use default search parameters
+ * ivf_flat::search_params search_params;
+ * // Use the same allocator across multiple searches to reduce the number of
+ * // cuda memory allocations
+ * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr);
+ * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr);
+ * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr);
+ * ...
+ * @endcode
+ * The exact size of the temporary buffer depends on multiple factors and is an implementation
+ * detail. However, you can safely specify a small initial size for the memory pool, so that only a
+ * few allocations happen to grow it during the first invocations of the `search`.
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices
+ *
+ * @param[in] handle
+ * @param[in] params configure the search
+ * @param[in] index ivf-flat constructed index
+ * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
+ * @param[in] n_queries the batch size
+ * @param[in] k the number of neighbors to find for each query.
+ * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
+ * [n_queries, k]
+ * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
+ * @param[in] mr an optional memory resource to use across the searches (you can provide a large
+ * enough memory pool here to avoid memory allocations within search).
+ */
+template
+inline void search(const handle_t& handle,
+ const search_params& params,
+ const index& index,
+ const T* queries,
+ uint32_t n_queries,
+ uint32_t k,
+ IdxT* neighbors,
+ float* distances,
+ rmm::mr::device_memory_resource* mr = nullptr)
+{
+ return raft::spatial::knn::ivf_flat::detail::search(
+ handle, params, index, queries, n_queries, k, neighbors, distances, mr);
+}
+
+/**
+ * @brief Search ANN using the constructed index.
+ *
+ * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
+ *
+ * Note, this function requires a temporary buffer to store intermediate results between cuda kernel
+ * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
+ * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
+ * eliminate entirely allocations happening within `search`:
+ * @code{.cpp}
+ * ...
+ * // Create a pooling memory resource with a pre-defined initial size.
+ * rmm::mr::pool_memory_resource mr(
+ * rmm::mr::get_current_device_resource(), 1024 * 1024);
+ * // use default search parameters
+ * ivf_flat::search_params search_params;
+ * // Use the same allocator across multiple searches to reduce the number of
+ * // cuda memory allocations
+ * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr);
+ * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr);
+ * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr);
+ * ...
+ * @endcode
+ * The exact size of the temporary buffer depends on multiple factors and is an implementation
+ * detail. However, you can safely specify a small initial size for the memory pool, so that only a
+ * few allocations happen to grow it during the first invocations of the `search`.
+ *
+ * @tparam value_t data element type
+ * @tparam idx_t type of the indices
+ * @tparam int_t precision / type of integral arguments
+ * @tparam matrix_idx_t matrix indexing type
+ *
+ * @param[in] handle
+ * @param[in] index ivf-flat constructed index
+ * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
+ * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
+ * [n_queries, k]
+ * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
+ * @param[in] params configure the search
+ * @param[in] k the number of neighbors to find for each query.
+ */
+template
+void search(const handle_t& handle,
+ const index& index,
+ raft::device_matrix_view queries,
+ raft::device_matrix_view neighbors,
+ raft::device_matrix_view distances,
+ const search_params& params,
+ int_t k)
+{
+ RAFT_EXPECTS(
+ queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
+ "Number of rows in output neighbors and distances matrices must equal the number of queries.");
+
+ RAFT_EXPECTS(
+ neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast(k),
+ "Number of columns in output neighbors and distances matrices must equal k");
+
+ RAFT_EXPECTS(queries.extent(1) == index.dim(),
+ "Number of query dimensions should equal number of dimensions in the index.");
+
+ return raft::spatial::knn::ivf_flat::detail::search(handle,
+ params,
+ index,
+ queries.data_handle(),
+ queries.extent(0),
+ k,
+ neighbors.data_handle(),
+ distances.data_handle(),
+ nullptr);
+}
+
+} // namespace raft::neighbors::ivf_flat
diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp
new file mode 100644
index 0000000000..c7e3798f5d
--- /dev/null
+++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp
@@ -0,0 +1,279 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "ann_types.hpp"
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace raft::neighbors::ivf_flat {
+
+/** Size of the interleaved group (see `index::data` description). */
+constexpr static uint32_t kIndexGroupSize = 32;
+
+struct index_params : ann::index_params {
+ /** The number of inverted lists (clusters) */
+ uint32_t n_lists = 1024;
+ /** The number of iterations searching for kmeans centers (index building). */
+ uint32_t kmeans_n_iters = 20;
+ /** The fraction of data to use during iterative kmeans building. */
+ double kmeans_trainset_fraction = 0.5;
+};
+
+struct search_params : ann::search_params {
+ /** The number of clusters to search. */
+ uint32_t n_probes = 20;
+};
+
+static_assert(std::is_aggregate_v);
+static_assert(std::is_aggregate_v);
+
+/**
+ * @brief IVF-flat index.
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ */
+template
+struct index : ann::index {
+ static_assert(!raft::is_narrowing_v,
+ "IdxT must be able to represent all values of uint32_t");
+
+ public:
+ /**
+ * Vectorized load/store size in elements, determines the size of interleaved data chunks.
+ *
+ * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum
+ * possible value by padding the `dim` of the data https://github.com/rapidsai/raft/issues/711
+ */
+ [[nodiscard]] constexpr inline auto veclen() const noexcept -> uint32_t { return veclen_; }
+ /** Distance metric used for clustering. */
+ [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType
+ {
+ return metric_;
+ }
+ /**
+ * Inverted list data [size, dim].
+ *
+ * The data consists of the dataset rows, grouped by their labels (into clusters/lists).
+ * Within each list (cluster), the data is grouped into blocks of `kIndexGroupSize` interleaved
+ * vectors. Note, the total index length is slightly larger than the source dataset length,
+ * because each cluster is padded by `kIndexGroupSize` elements.
+ *
+ * Interleaving pattern:
+ * within groups of `kIndexGroupSize` rows, the data is interleaved with the block size equal to
+ * `veclen * sizeof(T)`. That is, a chunk of `veclen` consecutive components of one row is
+ * followed by a chunk of the same size of the next row, and so on.
+ *
+ * __Example__: veclen = 2, dim = 6, kIndexGroupSize = 32, list_size = 31
+ *
+ * x[ 0, 0], x[ 0, 1], x[ 1, 0], x[ 1, 1], ... x[14, 0], x[14, 1], x[15, 0], x[15, 1],
+ * x[16, 0], x[16, 1], x[17, 0], x[17, 1], ... x[30, 0], x[30, 1], - , - ,
+ * x[ 0, 2], x[ 0, 3], x[ 1, 2], x[ 1, 3], ... x[14, 2], x[14, 3], x[15, 2], x[15, 3],
+ * x[16, 2], x[16, 3], x[17, 2], x[17, 3], ... x[30, 2], x[30, 3], - , - ,
+ * x[ 0, 4], x[ 0, 5], x[ 1, 4], x[ 1, 5], ... x[14, 4], x[14, 5], x[15, 4], x[15, 5],
+ * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - ,
+ *
+ */
+ inline auto data() noexcept -> device_mdspan, row_major>
+ {
+ return data_.view();
+ }
+ [[nodiscard]] inline auto data() const noexcept
+ -> device_mdspan, row_major>
+ {
+ return data_.view();
+ }
+
+ /** Inverted list indices: ids of items in the source data [size] */
+ inline auto indices() noexcept -> device_mdspan, row_major>
+ {
+ return indices_.view();
+ }
+ [[nodiscard]] inline auto indices() const noexcept
+ -> device_mdspan, row_major>
+ {
+ return indices_.view();
+ }
+
+ /** Sizes of the lists (clusters) [n_lists] */
+ inline auto list_sizes() noexcept -> device_mdspan, row_major>
+ {
+ return list_sizes_.view();
+ }
+ [[nodiscard]] inline auto list_sizes() const noexcept
+ -> device_mdspan, row_major>
+ {
+ return list_sizes_.view();
+ }
+
+ /**
+ * Offsets into the lists [n_lists + 1].
+ * The last value contains the total length of the index.
+ */
+ inline auto list_offsets() noexcept -> device_mdspan, row_major>
+ {
+ return list_offsets_.view();
+ }
+ [[nodiscard]] inline auto list_offsets() const noexcept
+ -> device_mdspan, row_major>
+ {
+ return list_offsets_.view();
+ }
+
+ /** k-means cluster centers corresponding to the lists [n_lists, dim] */
+ inline auto centers() noexcept -> device_mdspan, row_major>
+ {
+ return centers_.view();
+ }
+ [[nodiscard]] inline auto centers() const noexcept
+ -> device_mdspan, row_major>
+ {
+ return centers_.view();
+ }
+
+ /**
+ * (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists].
+ *
+ * NB: this may be empty if the index is empty or if the metric does not require the center norms
+ * calculation.
+ */
+ inline auto center_norms() noexcept
+ -> std::optional, row_major>>
+ {
+ if (center_norms_.has_value()) {
+ return std::make_optional, row_major>>(
+ center_norms_->view());
+ } else {
+ return std::nullopt;
+ }
+ }
+ [[nodiscard]] inline auto center_norms() const noexcept
+ -> std::optional, row_major>>
+ {
+ if (center_norms_.has_value()) {
+ return std::make_optional, row_major>>(
+ center_norms_->view());
+ } else {
+ return std::nullopt;
+ }
+ }
+
+ /** Total length of the index. */
+ [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return indices_.extent(0); }
+ /** Dimensionality of the data. */
+ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t
+ {
+ return centers_.extent(1);
+ }
+ /** Number of clusters/inverted lists. */
+ [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t
+ {
+ return centers_.extent(0);
+ }
+
+ // Don't allow copying the index for performance reasons (try avoiding copying data)
+ index(const index&) = delete;
+ index(index&&) = default;
+ auto operator=(const index&) -> index& = delete;
+ auto operator=(index&&) -> index& = default;
+ ~index() = default;
+
+ /** Construct an empty index. It needs to be trained and then populated. */
+ index(const handle_t& handle, raft::distance::DistanceType metric, uint32_t n_lists, uint32_t dim)
+ : ann::index(),
+ veclen_(calculate_veclen(dim)),
+ metric_(metric),
+ data_(make_device_mdarray(handle, make_extents(0, dim))),
+ indices_(make_device_mdarray(handle, make_extents(0))),
+ list_sizes_(make_device_mdarray(handle, make_extents(n_lists))),
+ list_offsets_(make_device_mdarray(handle, make_extents(n_lists + 1))),
+ centers_(make_device_mdarray(handle, make_extents(n_lists, dim))),
+ center_norms_(std::nullopt)
+ {
+ check_consistency();
+ }
+
+ /** Construct an empty index. It needs to be trained and then populated. */
+ index(const handle_t& handle, const index_params& params, uint32_t dim)
+ : index(handle, params.metric, params.n_lists, dim)
+ {
+ }
+
+ /**
+ * Replace the content of the index with new uninitialized mdarrays to hold the indicated amount
+ * of data.
+ */
+ void allocate(const handle_t& handle, IdxT index_size, bool allocate_center_norms)
+ {
+ data_ = make_device_mdarray(handle, make_extents(index_size, dim()));
+ indices_ = make_device_mdarray(handle, make_extents(index_size));
+ center_norms_ =
+ allocate_center_norms
+ ? std::optional(make_device_mdarray(handle, make_extents(n_lists())))
+ : std::nullopt;
+ check_consistency();
+ }
+
+ private:
+ /**
+ * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum
+ * possible value by padding the `dim` of the data https://github.com/rapidsai/raft/issues/711
+ */
+ uint32_t veclen_;
+ raft::distance::DistanceType metric_;
+ device_mdarray, row_major> data_;
+ device_mdarray, row_major> indices_;
+ device_mdarray, row_major> list_sizes_;
+ device_mdarray, row_major> list_offsets_;
+ device_mdarray, row_major> centers_;
+ std::optional, row_major>> center_norms_;
+
+ /** Throw an error if the index content is inconsistent. */
+ void check_consistency()
+ {
+ RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen");
+ RAFT_EXPECTS(data_.extent(0) == indices_.extent(0), "inconsistent index size");
+ RAFT_EXPECTS(data_.extent(1) == IdxT(centers_.extent(1)), "inconsistent data dimensionality");
+ RAFT_EXPECTS( //
+ (centers_.extent(0) == list_sizes_.extent(0)) && //
+ (centers_.extent(0) + 1 == list_offsets_.extent(0)) && //
+ (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)),
+ "inconsistent number of lists (clusters)");
+ RAFT_EXPECTS(reinterpret_cast(data_.data_handle()) % (veclen_ * sizeof(T)) == 0,
+ "The data storage pointer is not aligned to the vector length");
+ }
+
+ static auto calculate_veclen(uint32_t dim) -> uint32_t
+ {
+ // TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a
+ // template parameter (https://github.com/rapidsai/raft/issues/711)
+ uint32_t veclen = 16 / sizeof(T);
+ while (dim % veclen != 0) {
+ veclen = veclen >> 1;
+ }
+ return veclen;
+ }
+};
+
+} // namespace raft::neighbors::ivf_flat
diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh
new file mode 100644
index 0000000000..1e32d5d7ba
--- /dev/null
+++ b/cpp/include/raft/neighbors/ivf_pq.cuh
@@ -0,0 +1,194 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "ivf_pq_types.hpp"
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace raft::neighbors::ivf_pq {
+
+/**
+ * @brief Build the index from the dataset for efficient search.
+ *
+ * NB: Currently, the following distance metrics are supported:
+ * - L2Expanded
+ * - L2Unexpanded
+ * - InnerProduct
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::spatial::knn;
+ * // use default index parameters
+ * ivf_pq::index_params index_params;
+ * // create and fill the index from a [N, D] dataset
+ * auto index = ivf_pq::build(handle, index_params, dataset, N, D);
+ * // use default search parameters
+ * ivf_pq::search_params search_params;
+ * // search K nearest neighbours for each of the N queries
+ * ivf_pq::search(handle, search_params, index, queries, N, K, out_inds, out_dists);
+ * @endcode
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param handle
+ * @param params configure the index building
+ * @param[in] dataset a device pointer to a row-major matrix [n_rows, dim]
+ * @param n_rows the number of samples
+ * @param dim the dimensionality of the data
+ *
+ * @return the constructed ivf-pq index
+ */
+template
+inline auto build(
+ const handle_t& handle, const index_params& params, const T* dataset, IdxT n_rows, uint32_t dim)
+ -> index
+{
+ return raft::spatial::knn::ivf_pq::detail::build(handle, params, dataset, n_rows, dim);
+}
+
+/**
+ * @brief Build a new index containing the data of the original plus new extra vectors.
+ *
+ * Implementation note:
+ * The new data is clustered according to existing kmeans clusters, then the cluster
+ * centers are unchanged.
+ *
+ * Usage example:
+ * @code{.cpp}
+ * using namespace raft::spatial::knn;
+ * ivf_pq::index_params index_params;
+ * index_params.add_data_on_build = false; // don't populate index on build
+ * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
+ * // train the index from a [N, D] dataset
+ * auto index_empty = ivf_pq::build(handle, index_params, dataset, N, D);
+ * // fill the index with the data
+ * auto index = ivf_pq::extend(handle, index_empty, dataset, nullptr, N);
+ * @endcode
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param handle
+ * @param orig_index original index
+ * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
+ * @param[in] new_indices a device pointer to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
+ * here to imply a continuous range `[0...n_rows)`.
+ * @param n_rows the number of samples
+ *
+ * @return the constructed extended ivf-pq index
+ */
+template
+inline auto extend(const handle_t& handle,
+ const index& orig_index,
+ const T* new_vectors,
+ const IdxT* new_indices,
+ IdxT n_rows) -> index
+{
+ return raft::spatial::knn::ivf_pq::detail::extend(
+ handle, orig_index, new_vectors, new_indices, n_rows);
+}
+
+/**
+ * @brief Extend the index with the new data.
+ * *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ * @param handle
+ * @param[inout] index
+ * @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
+ * @param[in] new_indices a device pointer to a vector of indices [n_rows].
+ * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
+ * here to imply a continuous range `[0...n_rows)`.
+ * @param n_rows the number of samples
+ */
+template
+inline void extend(const handle_t& handle,
+ index* index,
+ const T* new_vectors,
+ const IdxT* new_indices,
+ IdxT n_rows)
+{
+ *index = extend(handle, *index, new_vectors, new_indices, n_rows);
+}
+
+/**
+ * @brief Search ANN using the constructed index.
+ *
+ * See the [ivf_pq::build](#ivf_pq::build) documentation for a usage example.
+ *
+ * Note, this function requires a temporary buffer to store intermediate results between cuda kernel
+ * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
+ * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or
+ * eliminate entirely allocations happening within `search`:
+ * @code{.cpp}
+ * ...
+ * // Create a pooling memory resource with a pre-defined initial size.
+ * rmm::mr::pool_memory_resource mr(
+ * rmm::mr::get_current_device_resource(), 1024 * 1024);
+ * // use default search parameters
+ * ivf_pq::search_params search_params;
+ * // Use the same allocator across multiple searches to reduce the number of
+ * // cuda memory allocations
+ * ivf_pq::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr);
+ * ivf_pq::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr);
+ * ivf_pq::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr);
+ * ...
+ * @endcode
+ * The exact size of the temporary buffer depends on multiple factors and is an implementation
+ * detail. However, you can safely specify a small initial size for the memory pool, so that only a
+ * few allocations happen to grow it during the first invocations of the `search`.
+ *
+ * @tparam T data element type
+ * @tparam IdxT type of the indices
+ *
+ * @param handle
+ * @param params configure the search
+ * @param index ivf-pq constructed index
+ * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
+ * @param n_queries the batch size
+ * @param k the number of neighbors to find for each query.
+ * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
+ * [n_queries, k]
+ * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
+ * @param mr an optional memory resource to use across the searches (you can provide a large enough
+ * memory pool here to avoid memory allocations within search).
+ */
+template
+inline void search(const handle_t& handle,
+ const search_params& params,
+ const index& index,
+ const T* queries,
+ uint32_t n_queries,
+ uint32_t k,
+ IdxT* neighbors,
+ float* distances,
+ rmm::mr::device_memory_resource* mr = nullptr)
+{
+ return raft::spatial::knn::ivf_pq::detail::search(
+ handle, params, index, queries, n_queries, k, neighbors, distances, mr);
+}
+
+} // namespace raft::neighbors::ivf_pq
diff --git a/cpp/include/raft/neighbors/ivf_pq_types.hpp b/cpp/include/raft/neighbors/ivf_pq_types.hpp
new file mode 100644
index 0000000000..3dbf004e95
--- /dev/null
+++ b/cpp/include/raft/neighbors/ivf_pq_types.hpp
@@ -0,0 +1,434 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "ann_types.hpp"
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace raft::neighbors::ivf_pq {
+
+/** A type for specifying how PQ codebooks are created. */
+enum class codebook_gen { // NOLINT
+ PER_SUBSPACE = 0, // NOLINT
+ PER_CLUSTER = 1, // NOLINT
+};
+
+struct index_params : ann::index_params {
+ /**
+ * The number of inverted lists (clusters)
+ *
+ * Hint: the number of vectors per cluster (`n_rows/n_lists`) should be approximately 1,000 to
+ * 10,000.
+ */
+ uint32_t n_lists = 1024;
+ /** The number of iterations searching for kmeans centers (index building). */
+ uint32_t kmeans_n_iters = 20;
+ /** The fraction of data to use during iterative kmeans building. */
+ double kmeans_trainset_fraction = 0.5;
+ /**
+ * The bit length of the vector element after compression by PQ.
+ *
+ * Possible values: [4, 5, 6, 7, 8].
+ *
+ * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search
+ * performance, but the lower the recall.
+ */
+ uint32_t pq_bits = 8;
+ /**
+ * The dimensionality of the vector after compression by PQ. When zero, an optimal value is
+ * selected using a heuristic.
+ *
+ * NB: `pq_dim * pq_bits` must be a multiple of 8.
+ *
+ * Hint: a smaller 'pq_dim' results in a smaller index size and better search performance, but
+ * lower recall. If 'pq_bits' is 8, 'pq_dim' can be set to any number, but multiple of 8 are
+ * desirable for good performance. If 'pq_bits' is not 8, 'pq_dim' should be a multiple of 8.
+ * For good performance, it is desirable that 'pq_dim' is a multiple of 32. Ideally, 'pq_dim'
+ * should be also a divisor of the dataset dim.
+ */
+ uint32_t pq_dim = 0;
+ /** How PQ codebooks are created. */
+ codebook_gen codebook_kind = codebook_gen::PER_SUBSPACE;
+ /**
+ * Apply a random rotation matrix on the input data and queries even if `dim % pq_dim == 0`.
+ *
+ * Note: if `dim` is not multiple of `pq_dim`, a random rotation is always applied to the input
+ * data and queries to transform the working space from `dim` to `rot_dim`, which may be slightly
+ * larger than the original space and and is a multiple of `pq_dim` (`rot_dim % pq_dim == 0`).
+ * However, this transform is not necessary when `dim` is multiple of `pq_dim`
+ * (`dim == rot_dim`, hence no need in adding "extra" data columns / features).
+ *
+ * By default, if `dim == rot_dim`, the rotation transform is initialized with the identity
+ * matrix. When `force_random_rotation == true`, a random orthogonal transform matrix is generated
+ * regardless of the values of `dim` and `pq_dim`.
+ */
+ bool force_random_rotation = false;
+};
+
+struct search_params : ann::search_params {
+ /** The number of clusters to search. */
+ uint32_t n_probes = 20;
+ /**
+ * Data type of look up table to be created dynamically at search time.
+ *
+ * Possible values: [CUDA_R_32F, CUDA_R_16F, CUDA_R_8U]
+ *
+ * The use of low-precision types reduces the amount of shared memory required at search time, so
+ * fast shared memory kernels can be used even for datasets with large dimansionality. Note that
+ * the recall is slightly degraded when low-precision type is selected.
+ */
+ cudaDataType_t lut_dtype = CUDA_R_32F;
+ /**
+ * Storage data type for distance/similarity computed at search time.
+ *
+ * Possible values: [CUDA_R_16F, CUDA_R_32F]
+ *
+ * If the performance limiter at search time is device memory access, selecting FP16 will improve
+ * performance slightly.
+ */
+ cudaDataType_t internal_distance_dtype = CUDA_R_32F;
+ /**
+ * Thread block size of the distance calculation kernel at search time.
+ * When zero, an optimal block size is selected using a heuristic.
+ *
+ * Possible values: [0, 256, 512, 1024]
+ */
+ uint32_t preferred_thread_block_size = 0;
+};
+
+static_assert(std::is_aggregate_v);
+static_assert(std::is_aggregate_v);
+
+/**
+ * @brief IVF-PQ index.
+ *
+ * In the IVF-PQ index, a database vector y is approximated with two level quantization:
+ *
+ * y = Q_1(y) + Q_2(y - Q_1(y))
+ *
+ * The first level quantizer (Q_1), maps the vector y to the nearest cluster center. The number of
+ * clusters is n_lists.
+ *
+ * The second quantizer encodes the residual, and it is defined as a product quantizer [1].
+ *
+ * A product quantizer encodes a `dim` dimensional vector with a `pq_dim` dimensional vector.
+ * First we split the input vector into `pq_dim` subvectors (denoted by u), where each u vector
+ * contains `pq_len` distinct components of y
+ *
+ * y_1, y_2, ... y_{pq_len}, y_{pq_len+1}, ... y_{2*pq_len}, ... y_{dim-pq_len+1} ... y_{dim}
+ * \___________________/ \____________________________/ \______________________/
+ * u_1 u_2 u_{pq_dim}
+ *
+ * Then each subvector encoded with a separate quantizer q_i, end the results are concatenated
+ *
+ * Q_2(y) = q_1(u_1),q_2(u_2),...,q_{pq_dim}(u_pq_dim})
+ *
+ * Each quantizer q_i outputs a code with pq_bit bits. The second level quantizers are also defined
+ * by k-means clustering in the corresponding sub-space: the reproduction values are the centroids,
+ * and the set of reproduction values is the codebook.
+ *
+ * When the data dimensionality `dim` is not multiple of `pq_dim`, the feature space is transformed
+ * using a random orthogonal matrix to have `rot_dim = pq_dim * pq_len` dimensions
+ * (`rot_dim >= dim`).
+ *
+ * The second-level quantizers are trained either for each subspace or for each cluster:
+ * (a) codebook_gen::PER_SUBSPACE:
+ * creates `pq_dim` second-level quantizers - one for each slice of the data along features;
+ * (b) codebook_gen::PER_CLUSTER:
+ * creates `n_lists` second-level quantizers - one for each first-level cluster.
+ * In either case, the centroids are again found using k-means clustering interpreting the data as
+ * having pq_len dimensions.
+ *
+ * [1] Product quantization for nearest neighbor search Herve Jegou, Matthijs Douze, Cordelia Schmid
+ *
+ * @tparam IdxT type of the indices in the source dataset
+ *
+ */
+template
+struct index : ann::index {
+ static_assert(!raft::is_narrowing_v