-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Random Ball Cover Algorithm for 2D Haversine/Euclidean (#213)
This PR is a proof of concept to use the triangle inequality to prune the tree of <img src="https://latex.codecogs.com/gif.latex?O(n^2)" title="O(n^2)" /> exhaustive distance computations into something smaller, such as on the order of <img src="https://latex.codecogs.com/gif.latex?O(c^{3/2}&space;*&space;\sqrt{n})" title="O(c^{3/2} * \sqrt{n})" /> where c is called an expansion constant, based on the dimensionality. This should (hopefully) be able to benefit both sparse and dense k-nearest neighbors and all algorithms that use them, hopefully providing a significant speedup for our sparse semirings primitive when only the k-nearest neighbors are desired. The goal here is to construct a tree out of the random ball cover algorithm such that we can utilize it in algorithms which would otherwise be able to make efficient use of a ball tree. However, there are additional challenges to this algorithm on the GPU, such as being able to batch the tree lookups. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - William Hicks (https://github.com/wphicks) - Dante Gama Dessavre (https://github.com/dantegd) URL: #213
- Loading branch information
Showing
21 changed files
with
2,645 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cstdint> | ||
|
||
#include <raft/linalg/distance_type.h> | ||
#include <thrust/transform.h> | ||
#include "ball_cover_common.h" | ||
#include "detail/ball_cover.cuh" | ||
#include "detail/ball_cover/common.cuh" | ||
|
||
namespace raft { | ||
namespace spatial { | ||
namespace knn { | ||
|
||
template <typename value_idx = std::int64_t, typename value_t, | ||
typename value_int = std::uint32_t> | ||
void rbc_build_index(const raft::handle_t &handle, | ||
BallCoverIndex<value_idx, value_t, value_int> &index) { | ||
ASSERT(index.n == 2, | ||
"Random ball cover currently only works in 2-dimensions"); | ||
if (index.metric == raft::distance::DistanceType::Haversine) { | ||
detail::rbc_build_index(handle, index, detail::HaversineFunc()); | ||
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || | ||
index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { | ||
detail::rbc_build_index(handle, index, detail::EuclideanFunc()); | ||
} else { | ||
RAFT_FAIL("Metric not support"); | ||
} | ||
|
||
index.set_index_trained(); | ||
} | ||
|
||
/** | ||
* Performs a faster exact knn in metric spaces using the triangle | ||
* inequality with a number of landmark points to reduce the | ||
* number of distance computations from O(n^2) to O(sqrt(n)). This | ||
* performs an all neighbors knn, which can reuse memory when | ||
* the index and query are the same array. This function will | ||
* build the index and assumes rbc_build_index() has not already | ||
* been called. | ||
* @tparam value_idx knn index type | ||
* @tparam value_t knn distance type | ||
* @tparam value_int type for integers, such as number of rows/cols | ||
* @param handle raft handle for resource management | ||
* @param index ball cover index which has not yet been built | ||
* @param k number of nearest neighbors to find | ||
* @param perform_post_filtering if this is false, only the closest k landmarks | ||
* are considered (which will return approximate | ||
* results). | ||
* @param[out] inds output knn indices | ||
* @param[out] dists output knn distances | ||
* @param weight a weight for overlap between the closest landmark and | ||
* the radius of other landmarks when pruning distances. | ||
* Setting this value below 1 can effectively turn off | ||
* computing distances against many other balls, enabling | ||
* approximate nearest neighbors. Recall can be adjusted | ||
* based on how many relevant balls are ignored. Note that | ||
* many datasets can still have great recall even by only | ||
* looking in the closest landmark. | ||
*/ | ||
template <typename value_idx = std::int64_t, typename value_t, | ||
typename value_int = std::uint32_t> | ||
void rbc_all_knn_query(const raft::handle_t &handle, | ||
BallCoverIndex<value_idx, value_t, value_int> &index, | ||
value_int k, value_idx *inds, value_t *dists, | ||
bool perform_post_filtering = true, float weight = 1.0) { | ||
ASSERT(index.n == 2, | ||
"Random ball cover currently only works in 2-dimensions"); | ||
if (index.metric == raft::distance::DistanceType::Haversine) { | ||
detail::rbc_all_knn_query(handle, index, k, inds, dists, | ||
detail::HaversineFunc(), perform_post_filtering, | ||
weight); | ||
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || | ||
index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { | ||
detail::rbc_all_knn_query(handle, index, k, inds, dists, | ||
detail::EuclideanFunc(), perform_post_filtering, | ||
weight); | ||
} else { | ||
RAFT_FAIL("Metric not supported"); | ||
} | ||
|
||
index.set_index_trained(); | ||
} | ||
|
||
/** | ||
* Performs a faster exact knn in metric spaces using the triangle | ||
* inequality with a number of landmark points to reduce the | ||
* number of distance computations from O(n^2) to O(sqrt(n)). This | ||
* function does not build the index and assumes rbc_build_index() has | ||
* already been called. Use this function when the index and | ||
* query arrays are different, otherwise use rbc_all_knn_query(). | ||
* @tparam value_idx index type | ||
* @tparam value_t distances type | ||
* @tparam value_int integer type for size info | ||
* @param handle raft handle for resource management | ||
* @param index ball cover index which has not yet been built | ||
* @param k number of nearest neighbors to find | ||
* @param query the | ||
* @param perform_post_filtering if this is false, only the closest k landmarks | ||
* are considered (which will return approximate | ||
* results). | ||
* @param[out] inds output knn indices | ||
* @param[out] dists output knn distances | ||
* @param weight a weight for overlap between the closest landmark and | ||
* the radius of other landmarks when pruning distances. | ||
* Setting this value below 1 can effectively turn off | ||
* computing distances against many other balls, enabling | ||
* approximate nearest neighbors. Recall can be adjusted | ||
* based on how many relevant balls are ignored. Note that | ||
* many datasets can still have great recall even by only | ||
* looking in the closest landmark. | ||
* @param k | ||
* @param inds | ||
* @param dists | ||
* @param n_samples | ||
*/ | ||
template <typename value_idx = std::int64_t, typename value_t, | ||
typename value_int = std::uint32_t> | ||
void rbc_knn_query(const raft::handle_t &handle, | ||
BallCoverIndex<value_idx, value_t, value_int> &index, | ||
value_int k, const value_t *query, value_int n_query_pts, | ||
value_idx *inds, value_t *dists, | ||
bool perform_post_filtering = true, float weight = 1.0) { | ||
ASSERT(index.n == 2, | ||
"Random ball cover currently only works in 2-dimensions"); | ||
if (index.metric == raft::distance::DistanceType::Haversine) { | ||
detail::rbc_knn_query(handle, index, k, query, n_query_pts, inds, dists, | ||
detail::HaversineFunc(), perform_post_filtering, | ||
weight); | ||
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || | ||
index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) { | ||
detail::rbc_knn_query(handle, index, k, query, n_query_pts, inds, dists, | ||
detail::EuclideanFunc(), perform_post_filtering, | ||
weight); | ||
} else { | ||
RAFT_FAIL("Metric not supported"); | ||
} | ||
} | ||
|
||
// TODO: implement functions for: | ||
// 4. rbc_eps_neigh() - given a populated index, perform query against different query array | ||
// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data | ||
|
||
} // namespace knn | ||
} // namespace spatial | ||
} // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
/* | ||
* Copyright (c) 2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/linalg/distance_type.h> | ||
#include <cstdint> | ||
#include <raft/handle.hpp> | ||
#include <rmm/device_uvector.hpp> | ||
|
||
namespace raft { | ||
namespace spatial { | ||
namespace knn { | ||
|
||
/** | ||
* Stores raw index data points, sampled landmarks, the 1-nns of index points | ||
* to their closest landmarks, and the ball radii of each landmark. This | ||
* class is intended to be constructed once and reused across subsequent | ||
* queries. | ||
* @tparam value_idx | ||
* @tparam value_t | ||
* @tparam value_int | ||
*/ | ||
template <typename value_idx, typename value_t, | ||
typename value_int = std::uint32_t> | ||
class BallCoverIndex { | ||
public: | ||
explicit BallCoverIndex(const raft::handle_t &handle_, const value_t *X_, | ||
value_int m_, value_int n_, | ||
raft::distance::DistanceType metric_) | ||
: handle(handle_), | ||
X(X_), | ||
m(m_), | ||
n(n_), | ||
metric(metric_), | ||
/** | ||
* the sqrt() here makes the sqrt(m)^2 a linear-time lower bound | ||
* | ||
* Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) | ||
*/ | ||
n_landmarks(sqrt(m_)), | ||
R_indptr(sqrt(m_) + 1, handle.get_stream()), | ||
R_1nn_cols(m_, handle.get_stream()), | ||
R_1nn_dists(m_, handle.get_stream()), | ||
R(sqrt(m_) * n_, handle.get_stream()), | ||
R_radius(sqrt(m_), handle.get_stream()), | ||
index_trained(false) {} | ||
|
||
value_idx *get_R_indptr() { return R_indptr.data(); } | ||
value_idx *get_R_1nn_cols() { return R_1nn_cols.data(); } | ||
value_t *get_R_1nn_dists() { return R_1nn_dists.data(); } | ||
value_t *get_R_radius() { return R_radius.data(); } | ||
value_t *get_R() { return R.data(); } | ||
const value_t *get_X() { return X; } | ||
|
||
bool is_index_trained() const { return index_trained; }; | ||
|
||
// This should only be set by internal functions | ||
void set_index_trained() { index_trained = true; } | ||
|
||
const raft::handle_t &handle; | ||
|
||
const value_int m; | ||
const value_int n; | ||
const value_int n_landmarks; | ||
|
||
const value_t *X; | ||
|
||
raft::distance::DistanceType metric; | ||
|
||
private: | ||
// CSR storing the neighborhoods for each data point | ||
rmm::device_uvector<value_idx> R_indptr; | ||
rmm::device_uvector<value_idx> R_1nn_cols; | ||
rmm::device_uvector<value_t> R_1nn_dists; | ||
|
||
rmm::device_uvector<value_t> R_radius; | ||
|
||
rmm::device_uvector<value_t> R; | ||
|
||
protected: | ||
bool index_trained; | ||
}; | ||
} // namespace knn | ||
} // namespace spatial | ||
} // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.