Skip to content

Commit

Permalink
Mdspan-ifying raft::spatial (#827)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Mark Hoemmen (https://github.com/mhoemmen)
  - Divye Gala (https://github.com/divyegala)

URL: #827
  • Loading branch information
cjnolet authored Oct 1, 2022
1 parent 8a31ae6 commit 2bd8d04
Show file tree
Hide file tree
Showing 15 changed files with 942 additions and 340 deletions.
25 changes: 22 additions & 3 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ARGS=$*
REPODIR=$(cd $(dirname $0); pwd)

VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean -v -g --install --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps"
HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<tool>]
HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<tool>] [--limit-tests=<targets>]
where <target> is:
clean - remove all existing build artifacts and configuration (start over)
libraft - build the raft C++ code only. Also builds the C-wrapper library
Expand All @@ -40,6 +40,7 @@ HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<to
the only option to be supported)
--minimal-deps - disables dependencies like thrust so they can be overridden.
can be useful for a pure header-only install
--limit-tests - semicolon-separated list of test executables to compile (e.g. TEST_SPATIAL;TEST_CLUSTER)
--allgpuarch - build for all supported GPU architectures
--buildfaiss - build faiss statically into raft
--install - install cmake targets
Expand All @@ -50,7 +51,7 @@ HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<to
to speedup the build process.
-h - print this text
default action (no args) is to build both libraft and raft-dask targets
default action (no args) is to build libraft, tests, pylibraft and raft-dask targets
"
LIBRAFT_BUILD_DIR=${LIBRAFT_BUILD_DIR:=${REPODIR}/cpp/build}
SPHINX_BUILD_DIR=${REPODIR}/docs
Expand All @@ -70,6 +71,8 @@ COMPILE_NN_LIBRARY=OFF
COMPILE_DIST_LIBRARY=OFF
ENABLE_NN_DEPENDENCIES=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NN_TEST;SPATIAL_TEST;STATS_TEST;UTILS_TEST"

ENABLE_thrust_DEPENDENCY=ON

CACHE_ARGS=""
Expand Down Expand Up @@ -136,6 +139,21 @@ function cacheTool {
fi
}

function limitTests {
# Check for option to limit the set of test binaries to build
if [[ -n $(echo $ARGS | { grep -E "\-\-limit\-tests" || true; } ) ]]; then
# There are possible weird edge cases that may cause this regex filter to output nothing and fail silently
# the true pipe will catch any weird edge cases that may happen and will cause the program to fall back
# on the invalid option error
LIMIT_TEST_TARGETS=$(echo $ARGS | sed -e 's/.*--limit-tests=//' -e 's/ .*//')
if [[ -n ${LIMIT_TEST_TARGETS} ]]; then
# Remove the full LIMIT_TEST_TARGETS argument from list of args so that it passes validArgs function
ARGS=${ARGS//--limit-tests=$LIMIT_TEST_TARGETS/}
TEST_TARGETS=${LIMIT_TEST_TARGETS}
fi
fi
}

if hasArg -h || hasArg --help; then
echo "${HELP}"
exit 0
Expand All @@ -145,6 +163,7 @@ fi
if (( ${NUMARGS} != 0 )); then
cmakeArgs
cacheTool
limitTests
for a in ${ARGS}; do
if ! (echo " ${VALIDARGS} " | grep -q " ${a} "); then
echo "Invalid option: ${a}"
Expand Down Expand Up @@ -194,7 +213,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
COMPILE_DIST_LIBRARY=ON
ENABLE_NN_DEPENDENCIES=ON
COMPILE_NN_LIBRARY=ON
CMAKE_TARGET="${CMAKE_TARGET};CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NN_TEST;SPATIAL_TEST;STATS_TEST;UTILS_TEST"
CMAKE_TARGET="${CMAKE_TARGET};${TEST_TARGETS}"
fi

if hasArg bench || (( ${NUMARGS} == 0 )); then
Expand Down
187 changes: 155 additions & 32 deletions cpp/include/raft/spatial/knn/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ namespace raft {
namespace spatial {
namespace knn {

template <typename value_idx = std::int64_t, typename value_t, typename value_int = std::uint32_t>
/**
* Builds and populates a previously unbuilt BallCoverIndex
* @tparam idx_t knn index type
* @tparam value_t knn value type
* @tparam int_t integral type for knn params
* @tparam matrix_idx_t matrix indexing type
* @param[in] handle library resource management handle
* @param[inout] index an empty (and not previous built) instance of BallCoverIndex
*/
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index)
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index)
{
ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
if (index.metric == raft::distance::DistanceType::Haversine) {
detail::rbc_build_index(handle, index, detail::HaversineFunc<value_t, value_int>());
detail::rbc_build_index(handle, index, detail::HaversineFunc<value_t, int_t>());
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
index.metric == raft::distance::DistanceType::L2SqrtUnexpanded) {
detail::rbc_build_index(handle, index, detail::EuclideanFunc<value_t, value_int>());
detail::rbc_build_index(handle, index, detail::EuclideanFunc<value_t, int_t>());
} else {
RAFT_FAIL("Metric not support");
}
Expand All @@ -55,18 +64,18 @@ void rbc_build_index(const raft::handle_t& handle,
* the index and query are the same array. This function will
* build the index and assumes rbc_build_index() has not already
* been called.
* @tparam value_idx knn index type
* @tparam idx_t knn index type
* @tparam value_t knn distance type
* @tparam value_int type for integers, such as number of rows/cols
* @param handle raft handle for resource management
* @param index ball cover index which has not yet been built
* @param k number of nearest neighbors to find
* @param perform_post_filtering if this is false, only the closest k landmarks
* @tparam int_t type for integers, such as number of rows/cols
* @param[in] handle raft handle for resource management
* @param[inout] index ball cover index which has not yet been built
* @param[in] k number of nearest neighbors to find
* @param[in] perform_post_filtering if this is false, only the closest k landmarks
* are considered (which will return approximate
* results).
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param weight a weight for overlap between the closest landmark and
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
Expand All @@ -75,11 +84,11 @@ void rbc_build_index(const raft::handle_t& handle,
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename value_idx = std::int64_t, typename value_t, typename value_int = std::uint32_t>
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void rbc_all_knn_query(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
value_int k,
value_idx* inds,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
int_t k,
idx_t* inds,
value_t* dists,
bool perform_post_filtering = true,
float weight = 1.0)
Expand All @@ -91,7 +100,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
k,
inds,
dists,
detail::HaversineFunc<value_t, value_int>(),
detail::HaversineFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
Expand All @@ -101,7 +110,7 @@ void rbc_all_knn_query(const raft::handle_t& handle,
k,
inds,
dists,
detail::EuclideanFunc<value_t, value_int>(),
detail::EuclideanFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else {
Expand All @@ -111,26 +120,78 @@ void rbc_all_knn_query(const raft::handle_t& handle,
index.set_index_trained();
}

/**
* Performs a faster exact knn in metric spaces using the triangle
* inequality with a number of landmark points to reduce the
* number of distance computations from O(n^2) to O(sqrt(n)). This
* performs an all neighbors knn, which can reuse memory when
* the index and query are the same array. This function will
* build the index and assumes rbc_build_index() has not already
* been called.
* @tparam idx_t knn index type
* @tparam value_t knn distance type
* @tparam int_t type for integers, such as number of rows/cols
* @tparam matrix_idx_t matrix indexing type
* @param[in] handle raft handle for resource management
* @param[in] index ball cover index which has not yet been built
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param[in] k number of nearest neighbors to find
* @param[in] perform_post_filtering if this is false, only the closest k landmarks
* are considered (which will return approximate
* results).
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
* approximate nearest neighbors. Recall can be adjusted
* based on how many relevant balls are ignored. Note that
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void rbc_all_knn_query(const raft::handle_t& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
int_t k,
bool perform_post_filtering = true,
float weight = 1.0)
{
RAFT_EXPECTS(index.n <= 3, "only 2d and 3d vectors are supported in current implementation");
RAFT_EXPECTS(k <= index.m,
"k must be less than or equal to the number of data points in the index");
RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast<matrix_idx_t>(k),
"Number of columns in output indices and distances matrices must be equal to k");

RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == index.get_X().extent(0),
"Number of rows in output indices and distances matrices must equal number of rows "
"in index matrix.");

rbc_all_knn_query(
handle, index, k, inds.data_handle(), dists.data_handle(), perform_post_filtering, weight);
}

/**
* Performs a faster exact knn in metric spaces using the triangle
* inequality with a number of landmark points to reduce the
* number of distance computations from O(n^2) to O(sqrt(n)). This
* function does not build the index and assumes rbc_build_index() has
* already been called. Use this function when the index and
* query arrays are different, otherwise use rbc_all_knn_query().
* @tparam value_idx index type
* @tparam idx_t index type
* @tparam value_t distances type
* @tparam value_int integer type for size info
* @param handle raft handle for resource management
* @param index ball cover index which has not yet been built
* @param k number of nearest neighbors to find
* @param query the
* @param perform_post_filtering if this is false, only the closest k landmarks
* @tparam int_t integer type for size info
* @param[in] handle raft handle for resource management
* @param[inout] index ball cover index which has not yet been built
* @param[in] k number of nearest neighbors to find
* @param[in] query the
* @param[in] perform_post_filtering if this is false, only the closest k landmarks
* are considered (which will return approximate
* results).
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param weight a weight for overlap between the closest landmark and
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
Expand All @@ -140,13 +201,13 @@ void rbc_all_knn_query(const raft::handle_t& handle,
* looking in the closest landmark.
* @param[in] n_query_pts number of query points
*/
template <typename value_idx = std::int64_t, typename value_t, typename value_int = std::uint32_t>
template <typename idx_t, typename value_t, typename int_t>
void rbc_knn_query(const raft::handle_t& handle,
BallCoverIndex<value_idx, value_t, value_int>& index,
value_int k,
BallCoverIndex<idx_t, value_t, int_t>& index,
int_t k,
const value_t* query,
value_int n_query_pts,
value_idx* inds,
int_t n_query_pts,
idx_t* inds,
value_t* dists,
bool perform_post_filtering = true,
float weight = 1.0)
Expand All @@ -160,7 +221,7 @@ void rbc_knn_query(const raft::handle_t& handle,
n_query_pts,
inds,
dists,
detail::HaversineFunc<value_t, value_int>(),
detail::HaversineFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded ||
Expand All @@ -172,14 +233,76 @@ void rbc_knn_query(const raft::handle_t& handle,
n_query_pts,
inds,
dists,
detail::EuclideanFunc<value_t, value_int>(),
detail::EuclideanFunc<value_t, int_t>(),
perform_post_filtering,
weight);
} else {
RAFT_FAIL("Metric not supported");
}
}

/**
* Performs a faster exact knn in metric spaces using the triangle
* inequality with a number of landmark points to reduce the
* number of distance computations from O(n^2) to O(sqrt(n)). This
* function does not build the index and assumes rbc_build_index() has
* already been called. Use this function when the index and
* query arrays are different, otherwise use rbc_all_knn_query().
* @tparam idx_t index type
* @tparam value_t distances type
* @tparam int_t integer type for size info
* @tparam matrix_idx_t
* @param[in] handle raft handle for resource management
* @param[in] index ball cover index which has not yet been built
* @param[in] query device matrix containing query data points
* @param[out] inds output knn indices
* @param[out] dists output knn distances
* @param[in] k number of nearest neighbors to find
* @param[in] perform_post_filtering if this is false, only the closest k landmarks
* are considered (which will return approximate
* results).
* @param[in] weight a weight for overlap between the closest landmark and
* the radius of other landmarks when pruning distances.
* Setting this value below 1 can effectively turn off
* computing distances against many other balls, enabling
* approximate nearest neighbors. Recall can be adjusted
* based on how many relevant balls are ignored. Note that
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void rbc_knn_query(const raft::handle_t& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
int_t k,
bool perform_post_filtering = true,
float weight = 1.0)
{
RAFT_EXPECTS(k <= index.m,
"k must be less than or equal to the number of data points in the index");
RAFT_EXPECTS(inds.extent(1) == dists.extent(1) && dists.extent(1) == static_cast<idx_t>(k),
"Number of columns in output indices and distances matrices must be equal to k");

RAFT_EXPECTS(inds.extent(0) == dists.extent(0) && dists.extent(0) == query.extent(0),
"Number of rows in output indices and distances matrices must equal number of rows "
"in search matrix.");

RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1),
"Number of columns in query and index matrices must match.");

rbc_knn_query(handle,
index,
k,
query.data_handle(),
query.extent(0),
inds.data_handle(),
dists.data_handle(),
perform_post_filtering,
weight);
}

// TODO: implement functions for:
// 4. rbc_eps_neigh() - given a populated index, perform query against different query array
// 5. rbc_all_eps_neigh() - populate a BallCoverIndex and query against training data
Expand Down
Loading

0 comments on commit 2bd8d04

Please sign in to comment.