Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAGRA #1375

Merged
merged 47 commits into from
Apr 6, 2023
Merged

CAGRA #1375

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
8793adb
Add CAGRA, initial experimental version
tfeher Mar 27, 2023
2748758
Restructuring search params in progress
tfeher Mar 27, 2023
c51cc7a
replacing printf statements with RAFT_LOG_DEBUG
tfeher Mar 27, 2023
25d35ad
remove topk.cu
tfeher Mar 27, 2023
9adb9b0
Fix logging, revert some of the search_params refactoring
tfeher Mar 28, 2023
9dd0d46
adding specializations
tfeher Mar 29, 2023
d844e78
corrections
tfeher Mar 29, 2023
7c7819c
Enabled test for distance values, test team size
tfeher Mar 29, 2023
7991a56
added int8 and uint8 test and specializations
tfeher Mar 29, 2023
eb46fcf
correct copyright year for test files
tfeher Mar 29, 2023
60dfb3d
temporarily disabling int8 & uint8 tests
tfeher Mar 29, 2023
1be9514
Adding new search_plan
tfeher Mar 29, 2023
7e8ba3f
single_cta params factored out
tfeher Mar 29, 2023
e7cd010
Single cta plan creation works
tfeher Mar 30, 2023
72d2dff
all search configs added to plan
tfeher Mar 30, 2023
0e30822
refactored compiles
tfeher Mar 31, 2023
48a5161
Search dispatch refatored
tfeher Mar 31, 2023
af96473
Remove old search dispatch
tfeher Mar 31, 2023
b9639a3
Remove old search specialization
tfeher Mar 31, 2023
07c4607
Restoring CMakeLists
tfeher Mar 31, 2023
bc4ca55
Add serialization
tfeher Mar 31, 2023
078ce17
fix style error'
tfeher Mar 31, 2023
6beb3f6
Fixing topk size for small dataset
tfeher Apr 2, 2023
a27e9a7
Fix top-k size in search
tfeher Apr 2, 2023
a0ee761
Restore team size calculation
tfeher Apr 2, 2023
6ae6d32
Error message if dim is incompatible
tfeher Apr 2, 2023
41b45e7
extend tests
tfeher Apr 2, 2023
82f718e
Disabling multi_cta tests
tfeher Apr 2, 2023
ce99ead
Merge branch 'branch-23.04' into cagra_experimental
cjnolet Apr 2, 2023
38155ff
Replace CAGRA_HOST_DEVICE macros with RAFT_HOST_DEVICE
tfeher Apr 3, 2023
d5f0a00
Fix docstring
tfeher Apr 3, 2023
2b0a14e
Remove subsampling code from cagra_build
tfeher Apr 3, 2023
f0d9210
Fix metric type test
tfeher Apr 3, 2023
358c09c
Decrease the number of template specializations
tfeher Apr 4, 2023
64a898e
Reorder search params
tfeher Apr 4, 2023
e3639a6
Fix style errors
tfeher Apr 5, 2023
3983147
Fix style
tfeher Apr 5, 2023
c52dd33
Fix typo
tfeher Apr 5, 2023
7a249b5
Add resources arg to prune
tfeher Apr 5, 2023
4afb03e
Remove all cagra specializations
tfeher Apr 5, 2023
93c470a
Remove unused cagra specialization header
tfeher Apr 5, 2023
f962f22
Make refine_rate arg std::optional
tfeher Apr 5, 2023
ccbe925
Replace hashmap_mode string with enum
tfeher Apr 5, 2023
619666c
Only keep test file for float data type
tfeher Apr 5, 2023
1df4859
Add constxpr
tfeher Apr 5, 2023
c9be192
Remove constexpr
tfeher Apr 5, 2023
7f73aa9
Merge branch 'branch-23.04' into cagra_experimental
benfred Apr 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 226 additions & 0 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "detail/cagra/cagra_build.cuh"
#include "detail/cagra/cagra_search.cuh"
#include "detail/cagra/graph_core.cuh"

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/neighbors/cagra_types.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace raft::neighbors::experimental::cagra {

/**
* @defgroup cagra CUDA ANN Graph-based nearest neighbor search
* @{
*/

/**
* @brief Build a kNN graph.
*
* The kNN graph is the first building block for CAGRA index.
* This function uses the IVF-PQ method to build a kNN graph.
*
* The output is a dense matrix that stores the neighbor indices for each pont in the dataset.
* Each point has the same number of neighbors.
*
* See [cagra::build](#cagra::build) for an alternative method.
*
* The following distance metrics are supported:
* - L2Expanded
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* ivf_pq::index_params build_params;
* ivf_pq::search_params search_params
* auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 128);
* // create knn graph
* cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params);
* auto pruned_gaph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), 64);
* cagra::prune(res, dataset, knn_graph.view(), pruned_graph.view());
* // Construct an index from dataset and pruned knn_graph
* auto index = cagra::index<T, IdxT>(res, build_params.metric(), dataset, pruned_graph.view());
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param[in] res raft resources
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree]
* @param[in] refine_rate refinement rate for ivf-pq search
* @param[in] build_params (optional) ivf_pq index building parameters for knn graph
* @param[in] search_params (optional) ivf_pq search parameters
*/
template <typename DataT, typename IdxT, typename accessor>
void build_knn_graph(raft::device_resources const& res,
mdspan<const DataT, matrix_extent<IdxT>, row_major, accessor> dataset,
raft::host_matrix_view<IdxT, IdxT, row_major> knn_graph,
tfeher marked this conversation as resolved.
Show resolved Hide resolved
std::optional<float> refine_rate = std::nullopt,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
detail::build_knn_graph(res, dataset, knn_graph, refine_rate, build_params, search_params);
}

/**
* @brief Prune a KNN graph.
*
* Decrease the number of neighbors for each node.
*
* See [cagra::build_knn_graph](#cagra::build_knn_graph) for usage example
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param[in] res raft resources
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
* @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows,
* knn_graph_degree]
* @param[out] new_graph a host matrix view of the pruned knn graph [n_rows, graph_degree]
*/
template <class DATA_T,
typename IdxT = uint32_t,
typename d_accessor =
host_device_accessor<std::experimental::default_accessor<DATA_T>, memory_type::device>,
typename g_accessor =
host_device_accessor<std::experimental::default_accessor<DATA_T>, memory_type::host>>
void prune(raft::device_resources const& res,
mdspan<const DATA_T, matrix_extent<IdxT>, row_major, d_accessor> dataset,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
{
detail::graph::prune(res, dataset, knn_graph, new_graph);
}

/**
* @brief Build the index from the dataset for efficient search.
*
* The build consist of two steps: build an intermediate knn-graph, and prune it to
* create the final graph. The index_params struct controls the node degree of these
* graphs.
*
* It is required that dataset and the pruned graph fit the GPU memory.
*
* To customize the parameters for knn-graph building and pruning, and to reuse the
* intermediate results, you could build the index in two steps using
* [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::prune](#cagra::prune).
*
* The following distance metrics are supported:
* - L2
*
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* cagra::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = cagra::build(res, index_params, dataset);
* // use default search parameters
* ivf_pq::search_params search_params;
* // search K nearest neighbours
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* ivf_pq::search(res, search_params, index, queries, neighbors, distances);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param[in] res
* @param[in] params parameters for building the index
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
*
* @return the constructed cagra index
*/
template <typename T,
typename IdxT = uint32_t,
typename Accessor =
host_device_accessor<std::experimental::default_accessor<T>, memory_type::host>>
index<T, IdxT> build(raft::device_resources const& res,
const index_params& params,
mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset)
{
size_t degree = params.intermediate_graph_degree;
if (degree >= dataset.extent(0)) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
dataset.extent(0));
degree = dataset.extent(0) - 1;
}
RAFT_EXPECTS(degree >= params.graph_degree,
"Intermediate graph degree cannot be smaller than final graph degree");

auto knn_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), degree);

build_knn_graph(res, dataset, knn_graph.view());

auto cagra_graph = raft::make_host_matrix<IdxT, IdxT>(dataset.extent(0), params.graph_degree);

prune<T, IdxT>(res, dataset, knn_graph.view(), cagra_graph.view());

// Construct an index from dataset and pruned knn graph.
return index<T, IdxT>(res, params.metric, dataset, cagra_graph.view());
}

/**
* @brief Search ANN using the constructed index.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx cagra index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::device_resources const& res,
const search_params& params,
const index<T, IdxT>& idx,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<float, IdxT, row_major> distances)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

detail::search_main(res, params, idx, queries, neighbors, distances);
}
/** @} */ // end group cagra

} // namespace raft::neighbors::experimental::cagra
154 changes: 154 additions & 0 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "detail/cagra/cagra_serialize.cuh"

namespace raft::neighbors::experimental::cagra {

/**
* \defgroup cagra_serialize CAGRA Serialize
* @{
*/

/**
* Write the index to an output stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create an output stream
* std::ostream os(std::cout.rdbuf());
* // create an index with `auto index = cagra::build(...);`
* raft::serialize(handle, os, index);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
*
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle, std::ostream& os, const index<T, IdxT>& index)
{
detail::serialize(handle, os, index);
}

/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* // create an index with `auto index = cagra::build(...);`
* raft::serialize(handle, filename, index);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
*
*/
template <typename T, typename IdxT>
void serialize(raft::device_resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
{
detail::serialize(handle, filename, index);
}

/**
* Load index from input stream
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create an input stream
* std::istream is(std::cin.rdbuf());
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::deserialize<T, IdxT>(handle, is);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] is input stream
*
* @return raft::neighbors::cagra::index<T, IdxT>
*/
template <typename T, typename IdxT>
index<T, IdxT> deserialize(raft::device_resources const& handle, std::istream& is)
{
return detail::deserialize<T, IdxT>(handle, is);
}

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <raft/core/device_resources.hpp>
*
* raft::device_resources handle;
*
* // create a string with a filepath
* std::string filename("/path/to/index");
* using T = float; // data element type
* using IdxT = int; // type of the index
* auto index = raft::deserialize<T, IdxT>(handle, filename);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] handle the raft handle
* @param[in] filename the name of the file that stores the index
*
* @return raft::neighbors::cagra::index<T, IdxT>
*/
template <typename T, typename IdxT>
index<T, IdxT> deserialize(raft::device_resources const& handle, const std::string& filename)
{
return detail::deserialize<T, IdxT>(handle, filename);
}

/**@}*/

} // namespace raft::neighbors::experimental::cagra
Loading