Paying down some tech debt on docs, runtime API, and cython #1055

Merged: 39 commits, Dec 6, 2022
Commits
a7822ff  Starting to include specific function overloads (cjnolet, Nov 18, 2022)
dbbeb73  Fixing compile error (cjnolet, Nov 18, 2022)
40742ac  Merge remote-tracking branch 'rapidsai/branch-22.12' into doc-2212-re… (cjnolet, Nov 18, 2022)
09bdf84  Fixing up docs for pylibraft (cjnolet, Nov 18, 2022)
c52c4c3  Fixing compile error (cjnolet, Nov 18, 2022)
a045ece  More updates. Adding headers for many of the files (cjnolet, Nov 23, 2022)
7260bb5  Follow-on fixes (cjnolet, Nov 30, 2022)
247586c  Merge branch 'branch-23.02' into doc-2212-remove_broken_docs (cjnolet, Nov 30, 2022)
742cce8  Separating some of the distance APIs into groups (cjnolet, Nov 30, 2022)
9eec82e  Fixes (cjnolet, Nov 30, 2022)
ba08b57  Fixing pydoc for pylibraft code examples (cjnolet, Dec 1, 2022)
f511446  Adding more grouping (cjnolet, Dec 1, 2022)
6756d8b  Creating and using doxygen groups for matrix and linalg. Stats to come (cjnolet, Dec 1, 2022)
b8d262a  Merge branch 'branch-22.12' into doc-2212-remove_broken_docs (cjnolet, Dec 1, 2022)
56cf720  Merge branch 'branch-23.02' into doc-2212-remove_broken_docs (cjnolet, Dec 1, 2022)
802e7fa  Separating stats into doxygen groups (cjnolet, Dec 2, 2022)
e70bdc5  Fixing some broken doxygen groupds (cjnolet, Dec 2, 2022)
3a2e72c  Fix v measure (cjnolet, Dec 2, 2022)
47fe901  Breaking down some of the categories to make docs easier to consume (cjnolet, Dec 2, 2022)
1962cd9  Removing random state from datagen category (cjnolet, Dec 2, 2022)
7766a0b  Another fix (cjnolet, Dec 2, 2022)
59d4583  Consolidating mdspan cython defitions (cjnolet, Dec 2, 2022)
6899197  Removing unneeded factory function added during troubleshooting (cjnolet, Dec 2, 2022)
4a315a8  Removing more unused stuff (cjnolet, Dec 2, 2022)
da8ad8d  Removing typedefs (cjnolet, Dec 2, 2022)
16b7b9e  Update cpp/include/raft/stats/kl_divergence.cuh (cjnolet, Dec 2, 2022)
74efe0f  Update docs/source/cpp_api/mdspan_mdspan.rst (cjnolet, Dec 2, 2022)
ef4e1c7  Moving a bunch of stuff around (cjnolet, Dec 2, 2022)
794c157  Removing things from nn cmake (cjnolet, Dec 2, 2022)
c9d1b7b  Merge branch 'imp-2302-consolidate_mdspan_pxd' into doc-2212-remove_b… (cjnolet, Dec 2, 2022)
2ad0216  Fixing test syntax (cjnolet, Dec 2, 2022)
7a8689b  Merge branch 'doc-2212-remove_broken_docs' of github.com:cjnolet/raft… (cjnolet, Dec 2, 2022)
5f95808  Fixing build (cjnolet, Dec 2, 2022)
bcb57af  Fixing build (cjnolet, Dec 2, 2022)
a6feb71  Fixing bad merge (cjnolet, Dec 2, 2022)
d253981  Fixing build (cjnolet, Dec 2, 2022)
b325274  Merge branch 'imp-2302-consolidate_mdspan_pxd' into doc-2212-remove_b… (cjnolet, Dec 2, 2022)
64e2ab5  Rng state (cjnolet, Dec 3, 2022)
02100b1  Fixing another pxd error (cjnolet, Dec 3, 2022)
165 changes: 75 additions & 90 deletions cpp/CMakeLists.txt
@@ -278,81 +278,83 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)
  if(RAFT_COMPILE_DIST_LIBRARY)
  add_library(
  raft_distance_lib
- src/distance/pairwise_distance.cu
- src/distance/fused_l2_min_arg.cu
- src/distance/update_centroids_float.cu
- src/distance/update_centroids_double.cu
- src/distance/cluster_cost_float.cu
- src/distance/cluster_cost_double.cu
- src/distance/kmeans_fit_float.cu
- src/distance/kmeans_fit_double.cu
- src/distance/specializations/detail/canberra.cu
- src/distance/specializations/detail/chebyshev.cu
- src/distance/specializations/detail/correlation.cu
- src/distance/specializations/detail/cosine.cu
- src/distance/specializations/detail/cosine.cu
- src/distance/specializations/detail/hamming_unexpanded.cu
- src/distance/specializations/detail/hellinger_expanded.cu
- src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
- src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
- src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
- src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
- src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
- src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
- src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
+ src/distance/distance/pairwise_distance.cu
+ src/distance/distance/fused_l2_min_arg.cu
+ src/distance/cluster/update_centroids_float.cu
+ src/distance/cluster/update_centroids_double.cu
+ src/distance/cluster/cluster_cost_float.cu
+ src/distance/cluster/cluster_cost_double.cu
+ src/distance/neighbors/refine.cu
+ src/distance/neighbors/ivfpq_search.cu
+ src/distance/cluster/kmeans_fit_float.cu
+ src/distance/cluster/kmeans_fit_double.cu
+ src/distance/distance/specializations/detail/canberra.cu
+ src/distance/distance/specializations/detail/chebyshev.cu
+ src/distance/distance/specializations/detail/correlation.cu
+ src/distance/distance/specializations/detail/cosine.cu
+ src/distance/distance/specializations/detail/cosine.cu
+ src/distance/distance/specializations/detail/hamming_unexpanded.cu
+ src/distance/distance/specializations/detail/hellinger_expanded.cu
+ src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
+ src/distance/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
+ src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu
+ src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu
+ src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
+ src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
  # These are somehow missing a kernel definition which is causing a compile error.
  # src/distance/specializations/detail/kernels/rbf_kernel_double.cu
  # src/distance/specializations/detail/kernels/rbf_kernel_float.cu
- src/distance/specializations/detail/kernels/tanh_kernel_double.cu
- src/distance/specializations/detail/kernels/tanh_kernel_float.cu
- src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
- src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
- src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
- src/distance/specializations/detail/l1_float_float_float_int.cu
- src/distance/specializations/detail/l1_float_float_float_uint32.cu
- src/distance/specializations/detail/l1_double_double_double_int.cu
- src/distance/specializations/detail/l2_expanded_float_float_float_int.cu
- src/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
- src/distance/specializations/detail/l2_expanded_double_double_double_int.cu
- src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
- src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
- src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
- src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
- src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
- src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
- src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
- src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
- src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
- src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
- src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
- src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
- src/distance/specializations/detail/russel_rao_double_double_double_int.cu
- src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
- src/distance/specializations/detail/russel_rao_float_float_float_int.cu
- src/distance/specializations/fused_l2_nn_double_int.cu
- src/distance/specializations/fused_l2_nn_double_int64.cu
- src/distance/specializations/fused_l2_nn_float_int.cu
- src/distance/specializations/fused_l2_nn_float_int64.cu
- src/nn/specializations/detail/ivfpq_build.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_search.cu
- src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
- src/nn/specializations/refine.cu
- src/random/specializations/rmat_rectangular_generator_int_double.cu
- src/random/specializations/rmat_rectangular_generator_int64_double.cu
- src/random/specializations/rmat_rectangular_generator_int_float.cu
- src/random/specializations/rmat_rectangular_generator_int64_float.cu
+ src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu
+ src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu
+ src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu
+ src/distance/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu
+ src/distance/distance/specializations/detail/l1_float_float_float_int.cu
+ src/distance/distance/specializations/detail/l1_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/l1_double_double_double_int.cu
+ src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu
+ src/distance/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu
+ src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
+ src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
+ src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
+ src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
+ src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
+ src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
+ src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
+ src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
+ src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu
+ src/distance/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
+ src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu
+ src/distance/distance/specializations/fused_l2_nn_double_int.cu
+ src/distance/distance/specializations/fused_l2_nn_double_int64.cu
+ src/distance/distance/specializations/fused_l2_nn_float_int.cu
+ src/distance/distance/specializations/fused_l2_nn_float_int64.cu
+ src/distance/neighbors/ivfpq_build.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_fast.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_fast.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
+ src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
+ src/distance/neighbors/specializations/detail/ivfpq_search_float_int64_t.cu
+ src/distance/neighbors/specializations/detail/ivfpq_search_float_uint64_t.cu
+ src/distance/neighbors/specializations/detail/ivfpq_search_float_uint32_t.cu
+ src/distance/random/rmat_rectangular_generator_int_double.cu
+ src/distance/random/rmat_rectangular_generator_int64_double.cu
+ src/distance/random/rmat_rectangular_generator_int_float.cu
+ src/distance/random/rmat_rectangular_generator_int64_float.cu
  )
  set_target_properties(
  raft_distance_lib
@@ -410,23 +412,6 @@ if(RAFT_COMPILE_NN_LIBRARY)
  src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
  src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
  src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
- src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
- src/nn/specializations/detail/ivfpq_build.cu
- src/nn/specializations/detail/ivfpq_search.cu
- src/nn/specializations/detail/ivfpq_search_float_int64_t.cu
- src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu
- src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
  src/nn/specializations/fused_l2_knn_long_float_true.cu
  src/nn/specializations/fused_l2_knn_long_float_false.cu
  src/nn/specializations/fused_l2_knn_int_float_true.cu
@@ -519,7 +504,7 @@ if(TARGET raft_distance_lib)
  EXPORT raft-distance-lib-exports
  )
  install(
- DIRECTORY include/raft_distance
+ DIRECTORY include/raft_runtime
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT distance
  )
143 changes: 85 additions & 58 deletions cpp/include/raft/distance/distance.cuh
@@ -25,14 +25,14 @@

  #include <raft/core/device_mdspan.hpp>

- namespace raft {
- namespace distance {
-
  /**
- * @defgroup pairwise_distance pairwise distance prims
+ * @defgroup pairwise_distance pointer-based pairwise distance prims
  * @{
  */

+ namespace raft {
+ namespace distance {
+
  /**
  * @brief Evaluate pairwise distances with the user epilogue lamba allowed
  * @tparam DistanceType which distance to evaluate
@@ -219,58 +219,6 @@ void distance(const InType* x,
  x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg);
  }

- /**
- * @brief Evaluate pairwise distances for the simple use case.
- *
- * Note: Only contiguous row- or column-major layouts supported currently.
- *
- * @tparam DistanceType which distance to evaluate
- * @tparam InType input argument type
- * @tparam AccType accumulation type
- * @tparam OutType output type
- * @tparam Index_ Index type
- * @param handle raft handle for managing expensive resources
- * @param x first set of points (size n*k)
- * @param y second set of points (size m*k)
- * @param dist output distance matrix (size n*m)
- * @param metric_arg metric argument (used for Minkowski distance)
- */
- template <raft::distance::DistanceType distanceType,
- typename InType,
- typename AccType,
- typename OutType,
- typename layout = raft::layout_c_contiguous,
- typename Index_ = int>
- void distance(raft::handle_t const& handle,
- raft::device_matrix_view<InType, Index_, layout> const x,
- raft::device_matrix_view<InType, Index_, layout> const y,
- raft::device_matrix_view<OutType, Index_, layout> dist,
- InType metric_arg = 2.0f)
- {
- RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
- RAFT_EXPECTS(dist.extent(0) == x.extent(0),
- "Number of rows in output must be equal to "
- "number of rows in X");
- RAFT_EXPECTS(dist.extent(1) == y.extent(0),
- "Number of columns in output must be equal to "
- "number of rows in Y");
-
- RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
- RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");
-
- constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;
-
- distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
- y.data_handle(),
- dist.data_handle(),
- x.extent(0),
- y.extent(0),
- x.extent(1),
- handle.get_stream(),
- is_rowmajor,
- metric_arg);
- }
-
  /**
  * @brief Convenience wrapper around 'distance' prim to convert runtime metric
  * into compile time for the purpose of dispatch
@@ -401,6 +349,85 @@ void pairwise_distance(const raft::handle_t& handle,
  handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg);
  }

+ /** @} */
+
+ /**
+ * \defgroup distance_mdspan Pairwise distance functions
+ * @{
+ */
+
+ /**
+ * @brief Evaluate pairwise distances for the simple use case.
+ *
+ * Note: Only contiguous row- or column-major layouts supported currently.
+ *
+ * Usage example:
+ * @code{.cpp}
+ * #include <raft/core/handle.hpp>
+ * #include <raft/core/device_mdarray.hpp>
+ * #include <raft/random/make_blobs.cuh>
+ * #include <raft/distance/distance.cuh>
+ *
+ * raft::handle_t handle;
+ * int n_samples = 5000;
+ * int n_features = 50;
+ *
+ * auto input = raft::make_device_matrix<float>(handle, n_samples, n_features);
+ * auto labels = raft::make_device_vector<int>(handle, n_samples);
+ * auto output = raft::make_device_matrix<float>(handle, n_samples, n_samples);
+ *
+ * raft::random::make_blobs(handle, input.view(), labels.view());
+ * auto metric = raft::distance::DistanceType::L2SqrtExpanded;
+ * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);
+ * @endcode
+ *
+ * @tparam DistanceType which distance to evaluate
+ * @tparam InType input argument type
+ * @tparam AccType accumulation type
+ * @tparam OutType output type
+ * @tparam Index_ Index type
+ * @param handle raft handle for managing expensive resources
+ * @param x first set of points (size n*k)
+ * @param y second set of points (size m*k)
+ * @param dist output distance matrix (size n*m)
+ * @param metric_arg metric argument (used for Minkowski distance)
+ */
+ template <raft::distance::DistanceType distanceType,
+ typename InType,
+ typename AccType,
+ typename OutType,
+ typename layout = raft::layout_c_contiguous,
+ typename Index_ = int>
+ void distance(raft::handle_t const& handle,
+ raft::device_matrix_view<InType, Index_, layout> const x,
+ raft::device_matrix_view<InType, Index_, layout> const y,
+ raft::device_matrix_view<OutType, Index_, layout> dist,
+ InType metric_arg = 2.0f)
+ {
+ RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
+ RAFT_EXPECTS(dist.extent(0) == x.extent(0),
+ "Number of rows in output must be equal to "
+ "number of rows in X");
+ RAFT_EXPECTS(dist.extent(1) == y.extent(0),
+ "Number of columns in output must be equal to "
+ "number of rows in Y");
+
+ RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
+ RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");
+
+ constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;
+
+ distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
+ y.data_handle(),
+ dist.data_handle(),
+ x.extent(0),
+ y.extent(0),
+ x.extent(1),
+ handle.get_stream(),
+ is_rowmajor,
+ metric_arg);
+ }
+
  /**
  * @brief Convenience wrapper around 'distance' prim to convert runtime metric
  * into compile time for the purpose of dispatch
@@ -449,9 +476,9 @@ void pairwise_distance(raft::handle_t const& handle,
  metric_arg);
  }

+ /** @} */
+
  }; // namespace distance
  }; // namespace raft
-
- /** @} */

  #endif
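
Note on the mdspan overload in distance.cuh: the diff shows two ideas that may be easier to follow in isolation. The `distance` overload checks a shape contract (x is n*k, y is m*k, dist is n*m), and the `pairwise_distance` wrapper turns a runtime metric into a compile-time template argument. The host-only sketch below illustrates both under loud assumptions: `Metric`, `MatrixView`, and the CPU loops are hypothetical stand-ins written for this note. They are not RAFT types, and this is not how RAFT computes distances on the GPU.

```cpp
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for illustration only; none of these are RAFT types.
enum class Metric { L1, L2SqrtUnexpanded };

// Non-owning, row-major 2-D view (a toy substitute for a device matrix view).
struct MatrixView {
  float* data;
  std::size_t rows;
  std::size_t cols;
  float& operator()(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Metric fixed at compile time: one instantiation per metric.
template <Metric M>
void distance(MatrixView x, MatrixView y, MatrixView dist)
{
  // Shape contract mirroring the checks in the diff: x is n*k, y is m*k, dist is n*m.
  if (x.cols != y.cols) throw std::invalid_argument("Number of columns must be equal.");
  if (dist.rows != x.rows) throw std::invalid_argument("Output rows must equal rows of x.");
  if (dist.cols != y.rows) throw std::invalid_argument("Output columns must equal rows of y.");

  for (std::size_t i = 0; i < x.rows; ++i) {
    for (std::size_t j = 0; j < y.rows; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < x.cols; ++k) {
        const float d = x(i, k) - y(j, k);
        // M is a compile-time constant, so the untaken branch is trivially dead code.
        if (M == Metric::L1) {
          acc += std::fabs(d);
        } else {
          acc += d * d;
        }
      }
      if (M == Metric::L2SqrtUnexpanded) { acc = std::sqrt(acc); }
      dist(i, j) = acc;
    }
  }
}

// Runtime metric converted to a compile-time template argument via a switch.
inline void pairwise_distance(MatrixView x, MatrixView y, MatrixView dist, Metric metric)
{
  switch (metric) {
    case Metric::L1: distance<Metric::L1>(x, y, dist); break;
    case Metric::L2SqrtUnexpanded: distance<Metric::L2SqrtUnexpanded>(x, y, dist); break;
    default: throw std::invalid_argument("Unknown metric");
  }
}
```

A caller would wrap its buffers in `MatrixView` and pass a runtime `Metric` to `pairwise_distance`. A mismatched output shape fails with an exception before any element is written, which is the role the `RAFT_EXPECTS` checks play in the real overload.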