Skip to content

Commit

Permalink
IVF-Flat Python wrappers (#1316)
Browse files Browse the repository at this point in the history
Add Python wrappers to IVF-Flat.

closes #1139 

It also adds a C++ interface to raft_runtime, which is called from the Cython wrappers. The corresponding specializations are defined but not used elsewhere (see #1238).

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Divye Gala (https://github.com/divyegala)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1316
  • Loading branch information
tfeher authored Mar 17, 2023
1 parent 8386807 commit 7074010
Show file tree
Hide file tree
Showing 32 changed files with 2,086 additions and 76 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ log
.DS_Store
dask-worker-space/
*.egg-info/
*.bin

## scikit-build
_skbuild
Expand Down
11 changes: 11 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,17 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/matrix/specializations/detail/select_k_float_int64_t.cu
src/distance/matrix/specializations/detail/select_k_half_uint32_t.cu
src/distance/matrix/specializations/detail/select_k_half_int64_t.cu
src/distance/neighbors/ivf_flat_search.cu
src/distance/neighbors/ivf_flat_build.cu
src/distance/neighbors/specializations/ivfflat_build_float_int64_t.cu
src/distance/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu
src/distance/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu
src/distance/neighbors/specializations/ivfflat_extend_float_int64_t.cu
src/distance/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu
src/distance/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu
src/distance/neighbors/specializations/ivfflat_search_float_int64_t.cu
src/distance/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu
src/distance/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu
src/distance/neighbors/ivfpq_build.cu
src/distance/neighbors/ivfpq_deserialize.cu
src/distance/neighbors/ivfpq_serialize.cu
Expand Down
94 changes: 69 additions & 25 deletions cpp/include/raft/neighbors/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ auto build(raft::device_resources const& handle,
*/
template <typename value_t, typename idx_t>
auto build(raft::device_resources const& handle,
raft::device_matrix_view<const value_t, idx_t, row_major> dataset,
const index_params& params) -> index<value_t, idx_t>
const index_params& params,
raft::device_matrix_view<const value_t, idx_t, row_major> dataset)
-> index<value_t, idx_t>
{
return raft::neighbors::ivf_flat::detail::build(handle,
params,
Expand All @@ -119,6 +120,52 @@ auto build(raft::device_resources const& handle,
static_cast<idx_t>(dataset.extent(1)));
}

/**
 * @brief Build the index from the dataset for efficient search.
 *
 * This overload stores the newly built index in the caller-provided `idx` (overwriting its
 * previous contents) instead of returning it.
 *
 * NB: Currently, the following distance metrics are supported:
 * - L2Expanded
 * - L2Unexpanded
 * - InnerProduct
 *
 * Usage example:
 * @code{.cpp}
 *   using namespace raft::neighbors;
 *   // use default index parameters
 *   ivf_flat::index_params index_params;
 *   // `index` is an existing ivf_flat::index<value_t, idx_t>; it is replaced by the index
 *   // built from the [N, D] dataset
 *   ivf_flat::build(handle, index_params, dataset, index);
 *   // use default search parameters
 *   ivf_flat::search_params search_params;
 *   // search the nearest neighbours for each query; the number of neighbours searched for is
 *   // the number of columns of out_inds / out_dists
 *   ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists);
 * @endcode
 *
 * @tparam value_t data element type
 * @tparam idx_t type of the indices in the source dataset
 *
 * @param[in] handle
 * @param[in] params configure the index building
 * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim]
 * @param[out] idx reference to ivf_flat::index; assigned the newly built index
 *
 */
template <typename value_t, typename idx_t>
void build(raft::device_resources const& handle,
           const index_params& params,
           raft::device_matrix_view<const value_t, idx_t, row_major> dataset,
           raft::neighbors::ivf_flat::index<value_t, idx_t>& idx)
{
  // The detail-level builder takes a raw pointer plus explicit sizes, so unpack the view first.
  const auto n_rows = static_cast<idx_t>(dataset.extent(0));
  const auto dim    = static_cast<idx_t>(dataset.extent(1));

  idx = raft::neighbors::ivf_flat::detail::build(
    handle, params, dataset.data_handle(), n_rows, dim);
}

/** @} */

/**
Expand Down Expand Up @@ -192,20 +239,19 @@ auto extend(raft::device_resources const& handle,
* @tparam idx_t type of the indices in the source dataset
*
* @param[in] handle
* @param[in] orig_index original index
* @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device pointer to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr`
* @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[in] orig_index original index
*
* @return the constructed extended ivf-flat index
*/
template <typename value_t, typename idx_t>
auto extend(raft::device_resources const& handle,
const index<value_t, idx_t>& orig_index,
raft::device_matrix_view<const value_t, idx_t, row_major> new_vectors,
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices = std::nullopt)
-> index<value_t, idx_t>
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices,
const index<value_t, idx_t>& orig_index) -> index<value_t, idx_t>
{
return extend<value_t, idx_t>(
handle,
Expand Down Expand Up @@ -270,24 +316,25 @@ void extend(raft::device_resources const& handle,
* // train the index from a [N, D] dataset
* auto index_empty = ivf_flat::build(handle, index_params, dataset);
* // fill the index with the data
* ivf_flat::extend(handle, index_empty, dataset);
* std::optional<raft::device_vector_view<const idx_t, idx_t>> no_opt = std::nullopt;
* ivf_flat::extend(handle, dataset, no_opt, &index_empty);
* @endcode
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
*
* @param[in] handle
* @param[inout] index
* @param[in] new_vectors a device pointer to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices a device pointer to a vector of indices [n_rows].
* @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[inout] index pointer to index, to be overwritten in-place
*/
template <typename value_t, typename idx_t>
void extend(raft::device_resources const& handle,
index<value_t, idx_t>* index,
raft::device_matrix_view<const value_t, idx_t, row_major> new_vectors,
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices = std::nullopt)
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices,
index<value_t, idx_t>* index)
{
extend(handle,
index,
Expand Down Expand Up @@ -386,30 +433,27 @@ void search(raft::device_resources const& handle,
* @tparam int_t precision / type of integral arguments
*
* @param[in] handle
* @param[in] params configure the search
* @param[in] index ivf-flat constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] params configure the search
* @param[in] k the number of neighbors to find for each query.
*/
template <typename value_t, typename idx_t, typename int_t>
template <typename value_t, typename idx_t>
void search(raft::device_resources const& handle,
const search_params& params,
const index<value_t, idx_t>& index,
raft::device_matrix_view<const value_t, idx_t, row_major> queries,
raft::device_matrix_view<idx_t, idx_t, row_major> neighbors,
raft::device_matrix_view<float, idx_t, row_major> distances,
const search_params& params,
int_t k)
raft::device_matrix_view<float, idx_t, row_major> distances)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(
neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast<idx_t>(k),
"Number of columns in output neighbors and distances matrices must equal k");
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1),
"Number of columns in output neighbors and distances matrices must be equal");

RAFT_EXPECTS(queries.extent(1) == index.dim(),
"Number of query dimensions should equal number of dimensions in the index.");
Expand All @@ -419,7 +463,7 @@ void search(raft::device_resources const& handle,
index,
queries.data_handle(),
static_cast<std::uint32_t>(queries.extent(0)),
static_cast<std::uint32_t>(k),
static_cast<std::uint32_t>(neighbors.extent(1)),
neighbors.data_handle(),
distances.data_handle(),
nullptr);
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/specializations.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/neighbors/specializations/ivf_flat.cuh>
#include <raft/neighbors/specializations/ivf_pq.cuh>
#include <raft/neighbors/specializations/refine.cuh>

Expand Down
54 changes: 54 additions & 0 deletions cpp/include/raft/neighbors/specializations/ivf_flat.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/ivf_flat.cuh>

namespace raft::neighbors::ivf_flat {

// Explicit instantiation declarations (`extern template`) for the public ivf_flat entry points
// (build, both extend overloads, search) with element type T and index type IdxT. Including this
// header suppresses implicit instantiation of these signatures in the including translation unit;
// the matching definitions are presumably provided by the precompiled specialization sources.
//
// Every view in the macro is indexed by IdxT, so the declarations stay consistent with each
// other for whatever index type is passed in.
//
// NOTE(review): the instantiations below use uint64_t, whereas the raft_runtime declarations
// and the specialization source file names added alongside this header use int64_t -- confirm
// which index type is intended.
#define RAFT_INST(T, IdxT)                                                    \
  extern template auto build(                                                 \
    raft::device_resources const& handle,                                     \
    const index_params& params,                                               \
    raft::device_matrix_view<const T, IdxT, row_major> dataset)               \
    ->index<T, IdxT>;                                                         \
                                                                              \
  extern template auto extend(                                                \
    raft::device_resources const& handle,                                     \
    raft::device_matrix_view<const T, IdxT, row_major> new_vectors,           \
    std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,    \
    const index<T, IdxT>& orig_index)                                         \
    ->index<T, IdxT>;                                                         \
                                                                              \
  extern template void extend(                                                \
    raft::device_resources const& handle,                                     \
    raft::device_matrix_view<const T, IdxT, row_major> new_vectors,           \
    std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,    \
    raft::neighbors::ivf_flat::index<T, IdxT>* idx);                          \
                                                                              \
  extern template void search(                                                \
    raft::device_resources const&,                                            \
    raft::neighbors::ivf_flat::search_params const&,                          \
    const raft::neighbors::ivf_flat::index<T, IdxT>&,                         \
    raft::device_matrix_view<const T, IdxT, row_major>,                       \
    raft::device_matrix_view<IdxT, IdxT, row_major>,                          \
    raft::device_matrix_view<float, IdxT, row_major>)

// The macro leaves its last declaration unterminated; the `;` at each use site closes it, so no
// stray empty declarations are produced.
RAFT_INST(float, uint64_t);
RAFT_INST(int8_t, uint64_t);
RAFT_INST(uint8_t, uint64_t);

#undef RAFT_INST
}  // namespace raft::neighbors::ivf_flat
68 changes: 68 additions & 0 deletions cpp/include/raft_runtime/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/ivf_flat_types.hpp>

namespace raft::runtime::neighbors::ivf_flat {

// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
auto extend(raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const raft::neighbors::ivf_flat::index<T, IdxT>& orig_index) \
->raft::neighbors::ivf_flat::index<T, IdxT>; \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_flat::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_flat::index<T, IdxT>& idx); \
\
void extend(raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx);

RAFT_INST_BUILD_EXTEND(float, int64_t)
RAFT_INST_BUILD_EXTEND(int8_t, int64_t)
RAFT_INST_BUILD_EXTEND(uint8_t, int64_t)

#undef RAFT_INST_BUILD_EXTEND

#define RAFT_INST_SEARCH(T, IdxT) \
void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
raft::neighbors::ivf_flat::index<T, IdxT> const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_INST_SEARCH(float, int64_t);
RAFT_INST_SEARCH(int8_t, int64_t);
RAFT_INST_SEARCH(uint8_t, int64_t);

#undef RAFT_INST_SEARCH

} // namespace raft::runtime::neighbors::ivf_flat
Loading

0 comments on commit 7074010

Please sign in to comment.