Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add brute_force index serialization #2036

Merged
merged 22 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
589b8ae
Update brute-force API to accept params struct
wphicks Nov 30, 2023
4050682
Update brute force extern template declarations
wphicks Dec 1, 2023
3844863
Begin adding brute-force index serialization
wphicks Dec 2, 2023
606f5e3
Add test header for brute force index
wphicks Dec 3, 2023
e47117e
Add filename serialization for brute force
wphicks Dec 3, 2023
26da060
Add brute force index tests
wphicks Dec 3, 2023
e0f63c8
Fix brute force explicit instantiations
wphicks Dec 3, 2023
07bd3d3
Fix dim serialization type
wphicks Dec 3, 2023
550263b
Merge branch 'branch-24.02' into fea-bf_ser
wphicks Dec 3, 2023
1deb340
Update cpp/include/raft/neighbors/brute_force_serialize.cuh
wphicks Dec 4, 2023
35e2ba6
Update cpp/include/raft/neighbors/brute_force_serialize.cuh
wphicks Dec 4, 2023
aee5e38
Use devArrMatchKnnPair for brute force index evaluation
wphicks Dec 4, 2023
3ce2091
Correct handling of exact test match
wphicks Dec 4, 2023
b65226f
Apply suggestions from code review
wphicks Dec 4, 2023
7476aff
Merge branch 'branch-24.02' into fea-bf_ser
wphicks Dec 4, 2023
ce12715
Remove unused host variables
wphicks Dec 4, 2023
0287c94
Make dataset serialization optional
wphicks Dec 5, 2023
e975ce2
Add tests for no-dataset serialization
wphicks Dec 5, 2023
820f26a
Merge branch 'branch-24.02' into fea-bf_ser
wphicks Dec 6, 2023
21cd1da
Document include_dataset parameter
wphicks Dec 7, 2023
e83a044
Merge remote-tracking branch 'refs/remotes/origin/fea-bf_ser' into fe…
wphicks Dec 7, 2023
f8c1f3c
Merge branch 'branch-24.02' into fea-bf_ser
wphicks Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ INSTALL_TARGET=install
BUILD_REPORT_METRICS=""
BUILD_REPORT_INCL_CACHE_STATS=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;NEIGHBORS_TEST;NEIGHBORS_ANN_BRUTE_FORCE_TEST;NEIGHBORS_ANN_CAGRA_TEST;NEIGHBORS_ANN_NN_DESCENT_TEST;NEIGHBORS_ANN_IVF_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NEIGHBORS_TEST;STATS_TEST;UTILS_TEST"
BENCH_TARGETS="CLUSTER_BENCH;CORE_BENCH;NEIGHBORS_BENCH;DISTANCE_BENCH;LINALG_BENCH;MATRIX_BENCH;SPARSE_BENCH;RANDOM_BENCH"

CACHE_ARGS=""
Expand Down Expand Up @@ -323,6 +323,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
if [[ $CMAKE_TARGET == *"CLUSTER_TEST"* || \
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_BRUTE_FORCE_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_CAGRA_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_IVF_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_ANN_NN_DESCENT_TEST"* || \
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/linalg/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "detail/norm.cuh"
#include "linalg_types.hpp"
#include <raft/core/mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>

#include <raft/core/device_mdspan.hpp>
Expand Down Expand Up @@ -154,4 +155,4 @@ void norm(raft::resources const& handle,
}; // end namespace linalg
}; // end namespace raft

#endif
#endif
45 changes: 45 additions & 0 deletions cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,21 @@ index<T> build(raft::resources const& res,
raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0) RAFT_EXPLICIT;

template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
index_params const& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<T, int64_t, row_major> distances) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& res,
search_params const& params,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
Expand Down Expand Up @@ -116,18 +129,50 @@ extern template void search<float, int>(
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template void search<float, int>(
raft::resources const& res,
search_params const& params,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template void search<float, int64_t>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template void search<float, int64_t>(
raft::resources const& res,
search_params const& params,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::distance::DistanceType metric,
float metric_arg);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
index_params const& params,
raft::device_matrix_view<const float, int64_t, row_major> dataset);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::host_matrix_view<const float, int64_t, row_major> dataset,
raft::distance::DistanceType metric,
float metric_arg);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
index_params const& params,
raft::host_matrix_view<const float, int64_t, row_major> dataset);
} // namespace raft::neighbors::brute_force

#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
Expand Down
63 changes: 61 additions & 2 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/copy.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/distance_types.hpp>
Expand Down Expand Up @@ -341,21 +342,33 @@ index<T> build(raft::resources const& res,
// certain distance metrics can benefit by pre-calculating the norms for the index dataset
// which lets us avoid calculating these at query time
std::optional<device_vector<T, int64_t>> norms;
// TODO(wphicks): Replace once mdbuffer is available
auto dataset_storage = std::optional<device_matrix<T, int64_t>>{};
auto dataset_view = [&res, &dataset_storage, dataset]() {
if constexpr (std::is_same_v<decltype(dataset),
raft::device_matrix_view<const T, int64_t, row_major>>) {
return dataset;
} else {
dataset_storage = make_device_matrix<T, int64_t>(res, dataset.extent(0), dataset.extent(1));
raft::copy(res, dataset_storage->view(), dataset);
return raft::make_const_mdspan(dataset_storage->view());
}
}();
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
norms = make_device_vector<T, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == raft::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset,
dataset_view,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
Expand All @@ -365,6 +378,25 @@ index<T> build(raft::resources const& res,
return index<T>(res, dataset, std::move(norms), metric, metric_arg);
}

/**
* @brief Build the index from the dataset for efficient search.
*
* @tparam T data element type
*
* @param[in] res
* @param[in] params configure the index building
* @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim]
*
* @return the constructed brute force index
*/
template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
index_params const& params,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset)
{
return build<T, Accessor>(res, dataset, params.metric, float(params.metric_arg));
}

/**
* @brief Brute Force search using the constructed index.
*
Expand All @@ -390,5 +422,32 @@ void search(raft::resources const& res,
{
raft::neighbors::detail::brute_force_search<T, IdxT>(res, idx, queries, neighbors, distances);
}

/**
* @brief Brute Force search using the constructed index.
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
* @param[in] res raft resources
* @param[in] params configure the search
* @param[in] idx brute force index
* @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
search_params const& params,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<T, int64_t, row_major> distances)
{
raft::neighbors::detail::brute_force_search<T, IdxT>(res, idx, queries, neighbors, distances);
}

/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
Loading