Skip to content

Commit

Permalink
Add brute_force index serialization (#2036)
Browse files Browse the repository at this point in the history
Add serialization and deserialization methods for the brute_force index. Also add overloads of the brute_force search and build functions that take index_params and search_params arguments, for API compatibility with other index types.

Authors:
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2036
  • Loading branch information
wphicks authored Dec 7, 2023
1 parent bc35f2b commit d2210a2
Show file tree
Hide file tree
Showing 10 changed files with 688 additions and 16 deletions.
3 changes: 2 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -323,6 +323,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
if [[ $CMAKE_TARGET == *"CLUSTER_TEST"* || \
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_BRUTE_FORCE_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/linalg/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "detail/norm.cuh"
#include "linalg_types.hpp"
#include <raft/core/mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>

#include <raft/core/device_mdspan.hpp>
Expand Down Expand Up @@ -154,4 +155,4 @@ void norm(raft::resources const& handle,
}; // end namespace linalg
}; // end namespace raft

#endif
#endif
45 changes: 45 additions & 0 deletions cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,21 @@ index<T> build(raft::resources const& res,
raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0) RAFT_EXPLICIT;

// Declaration of the `build` overload that receives its configuration as an `index_params`
// object and the dataset as a row-major mdspan with a caller-chosen accessor type.
// NOTE(review): RAFT_EXPLICIT presumably marks this template as supplied through explicit
// instantiation rather than defined in this header — confirm against the macro's definition.
template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
index_params const& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset) RAFT_EXPLICIT;

// Declaration of the `search` overload that takes no `search_params` argument: it receives the
// index, a view of the queries, and output views for the neighbor indices and distances.
// NOTE(review): RAFT_EXPLICIT presumably marks this template as supplied through explicit
// instantiation rather than defined in this header — confirm against the macro's definition.
template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<T, int64_t, row_major> distances) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& res,
search_params const& params,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
Expand Down Expand Up @@ -116,18 +129,50 @@ extern template void search<float, int>(
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

// Explicit instantiation declarations: each `extern template` below tells the compiler not to
// implicitly instantiate that specialization in translation units that include this header.
// NOTE(review): the matching explicit instantiation definitions are presumably compiled in a
// separate source file — confirm that every declaration here has one.

// search<float, int> — overload taking search_params.
extern template void search<float, int>(
raft::resources const& res,
search_params const& params,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

// search<float, int64_t> — overload without search_params.
extern template void search<float, int64_t>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

// search<float, int64_t> — overload taking search_params.
extern template void search<float, int64_t>(
raft::resources const& res,
search_params const& params,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

// build<float> from a device matrix view, with explicit metric and metric_arg.
extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::distance::DistanceType metric,
float metric_arg);

// build<float> from a device matrix view, configured through index_params.
extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const float, int64_t, row_major> dataset);

// build<float> from a host matrix view, with explicit metric and metric_arg.
extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::host_matrix_view<const float, int64_t, row_major> dataset,
raft::distance::DistanceType metric,
float metric_arg);

// build<float> from a host matrix view, configured through index_params.
extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const float, int64_t, row_major> dataset);
} // namespace raft::neighbors::brute_force

#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
Expand Down
63 changes: 61 additions & 2 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/copy.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -341,21 +342,33 @@ index<T> build(raft::resources const& res,
// certain distance metrics can benefit by pre-calculating the norms for the index dataset
// which lets us avoid calculating these at query time
std::optional<device_vector<T, int64_t>> norms;
// TODO(wphicks): Replace once mdbuffer is available
auto dataset_storage = std::optional<device_matrix<T, int64_t>>{};
auto dataset_view = [&res, &dataset_storage, dataset]() {
if constexpr (std::is_same_v<decltype(dataset),
raft::device_matrix_view<const T, int64_t, row_major>>) {
return dataset;
} else {
dataset_storage = make_device_matrix<T, int64_t>(res, dataset.extent(0), dataset.extent(1));
raft::copy(res, dataset_storage->view(), dataset);
return raft::make_const_mdspan(dataset_storage->view());
}
}();
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
norms = make_device_vector<T, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, whereas l2 distances need the squared norm
if (metric == raft::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
Expand All @@ -365,6 +378,25 @@ index<T> build(raft::resources const& res,
return index<T>(res, dataset, std::move(norms), metric, metric_arg);
}

/**
 * @brief Construct a brute force index, taking the build options from an `index_params` object.
 *
 * Convenience overload: reads `params.metric` and `params.metric_arg` and forwards them,
 * together with the dataset, to the `build` overload that accepts those values explicitly.
 *
 * @tparam T data element type
 * @tparam Accessor accessor type of the dataset mdspan
 *
 * @param[in] res raft resources, passed through to the forwarded `build` call
 * @param[in] params configure the index building; only `metric` and `metric_arg` are read here
 * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
 *
 * @return the constructed brute force index
 */
template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
               index_params const& params,
               mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
{
  // The metric argument is explicitly converted to float before it is forwarded.
  auto const forwarded_metric_arg = static_cast<float>(params.metric_arg);
  return build<T, Accessor>(res, dataset, params.metric, forwarded_metric_arg);
}

/**
* @brief Brute Force search using the constructed index.
*
Expand All @@ -390,5 +422,32 @@ void search(raft::resources const& res,
{
raft::neighbors::detail::brute_force_search<T, IdxT>(res, idx, queries, neighbors, distances);
}

/**
 * @brief Brute Force search against a constructed index, with a `search_params` argument.
 *
 * This overload accepts `params` so that its signature carries a search-parameters argument;
 * the body does not read `params` and forwards the remaining arguments unchanged to
 * `raft::neighbors::detail::brute_force_search`.
 *
 * @tparam T data element type
 * @tparam IdxT type of the indices
 *
 * @param[in] res raft resources
 * @param[in] params configure the search (not read by this overload)
 * @param[in] idx brute force index
 * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
 * @param[out] neighbors a device matrix view to the indices of the neighbors in the source
 *                       dataset [n_queries, k]
 * @param[out] distances a device matrix view to the distances to the selected neighbors
 *                       [n_queries, k]
 */
template <typename T, typename IdxT>
void search(raft::resources const& res,
            search_params const& params,
            const index<T>& idx,
            raft::device_matrix_view<const T, int64_t, row_major> queries,
            raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
            raft::device_matrix_view<T, int64_t, row_major> distances)
{
  // Local alias for the implementation namespace; the call itself is unchanged.
  namespace impl = raft::neighbors::detail;
  impl::brute_force_search<T, IdxT>(res, idx, queries, neighbors, distances);
}

/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
Loading

0 comments on commit d2210a2

Please sign in to comment.