Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mdspan-ifying raft::spatial #827

Merged
merged 38 commits into from
Oct 1, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
5d697b3
Breaking apart mdspan/mdarray into host_ and device_ variants
cjnolet Sep 7, 2022
0e6cc86
Updates
cjnolet Sep 7, 2022
71922f6
Fixing style
cjnolet Sep 7, 2022
b0e5a02
Separating host_span and device_span as well
cjnolet Sep 7, 2022
d69b163
Cleanup and getting to build
cjnolet Sep 7, 2022
50d750b
Updates
cjnolet Sep 7, 2022
6fda1fd
Fixing docs
cjnolet Sep 8, 2022
dded6ed
Updating readme to use proper header paths
cjnolet Sep 8, 2022
3aeb530
More updates based on review feedback
cjnolet Sep 8, 2022
ca354f3
Mdspanifying spatial/knn functions
cjnolet Sep 14, 2022
46b0750
Getting knn test to build
cjnolet Sep 14, 2022
a52bd9c
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 14, 2022
76a469d
Merge branch 'branch-22.10' into imp-2210-host_device_mdspan
cjnolet Sep 14, 2022
b6c758c
Fixing style
cjnolet Sep 14, 2022
3838690
Merge branch 'imp-2210-host_device_mdspan' into fea-2210-mdspanified_knn
cjnolet Sep 15, 2022
105a3a6
Trying to FIND_RAFT_CPP on by default
cjnolet Sep 15, 2022
7eae6e3
Fixing bad merge
cjnolet Sep 15, 2022
019e358
Merge remote-tracking branch 'rapidsai/branch-22.10' into imp-2210-ho…
cjnolet Sep 15, 2022
08c5648
Merge branch 'imp-2210-host_device_mdspan' into fea-2210-mdspanified_knn
cjnolet Sep 15, 2022
141a2d1
Fixing knn wrapper
cjnolet Sep 19, 2022
6cd27af
mdspanidying random ball cover
cjnolet Sep 19, 2022
4870554
mdspan-ifying ivf_flat, rbc, and epsilon neighborhoods
cjnolet Sep 19, 2022
428a9e6
Fixing last compile error
cjnolet Sep 19, 2022
a612bc2
Updating ball cover specializations and API
cjnolet Sep 20, 2022
e63d121
Removing stream destroy from eps neigh tests
cjnolet Sep 20, 2022
d111743
Updating docs
cjnolet Sep 20, 2022
b1e834c
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 22, 2022
bbf2ab1
Updates based on review feedback
cjnolet Sep 26, 2022
f528697
Updates based on review feedback
cjnolet Sep 26, 2022
90bbb33
Getting to nbuild
cjnolet Sep 26, 2022
8cb3ec5
Fixing style
cjnolet Sep 26, 2022
4c552bc
Adding weight back into rbc
cjnolet Sep 27, 2022
53a254e
Style check
cjnolet Sep 27, 2022
9939464
Updating docs to include [in] and [out]
cjnolet Sep 29, 2022
5363a2d
Removing defaults on template args
cjnolet Sep 30, 2022
b4377cb
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 30, 2022
dec2f81
Adding limit-tests to build.sh. Removing default template args per
cjnolet Sep 30, 2022
5313cb9
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 157 additions & 22 deletions cpp/include/raft/spatial/knn/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,28 @@ namespace raft {
namespace spatial {
namespace knn {

template <typename value_idx = std::int64_t, typename value_t, typename value_int = std::uint32_t>
/**
* Builds and populates a previously unbuilt BallCoverIndex
* @tparam idx_t knn index type
* @tparam value_t knn value type
* @tparam int_t integral type for knn params
* @tparam matrix_idx_t matrix indexing type
* @param handle library resource management handle
* @param index an empty (and not previous built) instance of BallCoverIndex
*/
template <typename idx_t = std::int64_t,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
typename value_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index)
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index)
{
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_build_index(handle, index, detail::HaversineFunc<value_t, value_int>());
detail::rbc_build_index(handle, index, detail::HaversineFunc<value_t, int_t>());
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) {
detail::rbc_build_index(handle, index, detail::EuclideanFunc<value_t, value_int>());
detail::rbc_build_index(handle, index, detail::EuclideanFunc<value_t, int_t>());
} else {
RAFT_FAIL("Metric not support");
}
Expand All @@ -55,9 +67,9 @@ void rbc_build_index(const raft::handle_t& handle,
* the index and query are the same array. This function will
* build the index and assumes rbc_build_index() has not already
* been called.
* @tparam value_idx knn index type
* @tparam idx_t knn index type
* @tparam value_t knn distance type
* @tparam value_int type for integers, such as number of rows/cols
* @tparam int_t type for integers, such as number of rows/cols
* @param handle raft handle for resource management
* @param index ball cover index which has not yet been built
* @param k number of nearest neighbors to find
Expand All @@ -75,11 +87,14 @@ void rbc_build_index(const raft::handle_t& handle,
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename value_idx = std::int64_t, typename value_t, typename value_int = std::uint32_t>
template <typename idx_t = std::int64_t,
typename value_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
void rbc_all_knn_query(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
value_int k,
value_idx* inds,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
int_t k,
idx_t* inds,
value_t* dists,
bool perform_post_filtering = true,
float weight = 1.0)
Expand All @@ -91,7 +106,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
k,
inds,
dists,
detail::HaversineFunc<value_t, value_int>(),
detail::HaversineFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
Expand All @@ -101,7 +116,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
k,
inds,
dists,
detail::EuclideanFunc<value_t, value_int>(),
detail::EuclideanFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else {
Expand All @@ -111,16 +126,71 @@ void rbc_all_knn_query(const raft::handle_t& handle,
index.set_index_trained();
}

/**
* Performs a faster exact knn in metric spaces using the triangle
* inequality with a number of landmark points to reduce the
* number of distance computations from O(n^2) to O(sqrt(n)). This
* performs an all neighbors knn, which can reuse memory when
* the index and query are the same array. This function will
* build the index and assumes rbc_build_index() has not already
* been called.
* @tparam idx_t knn index type
* @tparam value_t knn distance type
* @tparam int_t type for integers, such as number of rows/cols
* @tparam matrix_idx_t matrix indexing type
* @param[in] handle raft handle for resource management
* @param[in] index ball cover index which has not yet been built
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param[in] k number of nearest neighbors to find
* @param[in] perform_post_filtering if this is false, only the closest k landmarks
* are considered (which will return approximate
* results).
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
* approximate nearest neighbors. Recall can be adjusted
* based on how many relevant balls are ignored. Note that
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename idx_t = std::int64_t,
typename value_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
void rbc_all_knn_query(const raft::handle_t& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
int_t k = 5,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
bool perform_post_filtering = true,
float weight = 1.0)
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
{
RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
RAFT_EXPECTS(k <= index.m,
"k must be less than or equal to the number of data points in the index");
RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast<idx_t>(k),
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
"Number of columns in output indices and distances matrices must be equal to k");

RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0),
"Number of rows in output indices and distances matrices must equal number of rows "
"in index matrix.");

rbc_all_knn_query(
handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight);
}

/**
* Performs a faster exact knn in metric spaces using the triangle
* inequality with a number of landmark points to reduce the
* number of distance computations from O(n^2) to O(sqrt(n)). This
* function does not build the index and assumes rbc_build_index() has
* already been called. Use this function when the index and
* query arrays are different, otherwise use rbc_all_knn_query().
* @tparam value_idx index type
* @tparam idx_t index type
* @tparam value_t distances type
* @tparam value_int integer type for size info
* @tparam int_t integer type for size info
* @param handle raft handle for resource management
* @param index ball cover index which has not yet been built
* @param k number of nearest neighbors to find
Expand All @@ -130,7 +200,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
* results).
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param weight a weight for overlap between the closest landmark and
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
Expand All @@ -140,13 +210,13 @@ void rbc_all_knn_query(const raft::handle_t& handle,
* looking in the closest landmark.
* @param[in] n_query_pts number of query points
*/
template <typename value_idx = std::int64_t, typename value_t, typename value_int = std::uint32_t>
template <typename idx_t = std::int64_t, typename value_t, typename int_t = std::uint32_t>
void rbc_knn_query(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
value_int k,
BallCoverIndex<idx_t, value_t, int_t>& index,
int_t k,
const value_t* query,
value_int n_query_pts,
value_idx* inds,
int_t n_query_pts,
idx_t* inds,
value_t* dists,
bool perform_post_filtering = true,
float weight = 1.0)
Expand All @@ -160,7 +230,7 @@ void rbc_knn_query(const raft::handle_t& handle,
n_query_pts,
inds,
dists,
detail::HaversineFunc<value_t, value_int>(),
detail::HaversineFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
Expand All @@ -172,14 +242,79 @@ void rbc_knn_query(const raft::handle_t& handle,
n_query_pts,
inds,
dists,
detail::EuclideanFunc<value_t, value_int>(),
detail::EuclideanFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else {
RAFT_FAIL("Metric not supported");
}
}

/**
* Performs a faster exact knn in metric spaces using the triangle
* inequality with a number of landmark points to reduce the
* number of distance computations from O(n^2) to O(sqrt(n)). This
* function does not build the index and assumes rbc_build_index() has
* already been called. Use this function when the index and
* query arrays are different, otherwise use rbc_all_knn_query().
* @tparam idx_t index type
* @tparam value_t distances type
* @tparam int_t integer type for size info
* @tparam matrix_idx_t
* @param[in] handle raft handle for resource management
* @param[in] index ball cover index which has not yet been built
* @param[in] query device matrix containing query data points
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param[in] k number of nearest neighbors to find
* @param[in] perform_post_filtering if this is false, only the closest k landmarks
* are considered (which will return approximate
* results).
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
* approximate nearest neighbors. Recall can be adjusted
* based on how many relevant balls are ignored. Note that
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename idx_t = std::int64_t,
typename value_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
void rbc_knn_query(const raft::handle_t& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
int_t k = 5,
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
bool perform_post_filtering = true,
float weight = 1.0)
{
RAFT_EXPECTS(k <= index.m,
"k must be less than or equal to the number of data points in the index");
RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast<idx_t>(k),
"Number of columns in output indices and distances matrices must be equal to k");

RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0),
"Number of rows in output indices and distances matrices must equal number of rows "
"in search matrix.");

RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1),
"Number of columns in query and index matrices must match.");

rbc_knn_query(handle,
index,
k,
query.data_handle(),
query.extent(0),
inds.data_handle(),
dists.data_handle(),
perform_post_filtering,
weight);
}

// TODO: implement functions for:
// 4. rbc_eps_neigh() - given a populated index, perform query against different query array
// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data
Expand Down
Loading