Skip to content

Commit

Permalink
Paying down some tech debt on docs, runtime API, and cython (#1055)
Browse files Browse the repository at this point in the history
This PR pays down some tech debt and cleans up some things. On the surface, you'll notice many files have been touched or modified but the modifications are largely confined to a few major categories of changes:

1. Fixes some of the issues found with the doc updates in 22.12. 

2. Breaks some of the docs for the c++ namespaces down into multiple sections to make them easier to navigate and consume

3. Renames raft_distance directory into more appropriately named raft_runtime. (This is also in preparation to eventually rename the libraft-distance library into libraft once we can remove the FAISS dependency.

4. Separates out some runtime source files and APIs that were being mistakenly combined with the template specializations API

5. Consolidates multiple mdspan.pxd files into a single file. 

6. Consistently uses `cpp` directory for new(er) pxd files, nested into their respective packages.

7. Consistently uses doxygen groups in many of the namespaces.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)
  - Divye Gala (https://github.com/divyegala)

URL: #1055
  • Loading branch information
cjnolet authored Dec 6, 2022
1 parent bb888f4 commit 092c515
Show file tree
Hide file tree
Showing 229 changed files with 3,541 additions and 1,718 deletions.
165 changes: 75 additions & 90 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -278,81 +278,83 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)
if(RAFT_COMPILE_DIST_LIBRARY)
add_library(
raft_distance_lib
src/distance/pairwise_distance.cu
src/distance/fused_l2_min_arg.cu
src/distance/update_centroids_float.cu
src/distance/update_centroids_double.cu
src/distance/cluster_cost_float.cu
src/distance/cluster_cost_double.cu
src/distance/kmeans_fit_float.cu
src/distance/kmeans_fit_double.cu
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
src/distance/specializations/detail/cosine.cu
src/distance/specializations/detail/cosine.cu
src/distance/specializations/detail/hamming_unexpanded.cu
src/distance/specializations/detail/hellinger_expanded.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
src/distance/distance/pairwise_distance.cu
src/distance/distance/fused_l2_min_arg.cu
src/distance/cluster/update_centroids_float.cu
src/distance/cluster/update_centroids_double.cu
src/distance/cluster/cluster_cost_float.cu
src/distance/cluster/cluster_cost_double.cu
src/distance/neighbors/refine.cu
src/distance/neighbors/ivfpq_search.cu
src/distance/cluster/kmeans_fit_float.cu
src/distance/cluster/kmeans_fit_double.cu
src/distance/distance/specializations/detail/canberra.cu
src/distance/distance/specializations/detail/chebyshev.cu
src/distance/distance/specializations/detail/correlation.cu
src/distance/distance/specializations/detail/cosine.cu
src/distance/distance/specializations/detail/cosine.cu
src/distance/distance/specializations/detail/hamming_unexpanded.cu
src/distance/distance/specializations/detail/hellinger_expanded.cu
src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
src/distance/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu
src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu
src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
# These are somehow missing a kernel definition which is causing a compile error.
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# src/distance/specializations/detail/kernels/rbf_kernel_float.cu
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
src/distance/specializations/detail/l1_float_float_float_int.cu
src/distance/specializations/detail/l1_float_float_float_uint32.cu
src/distance/specializations/detail/l1_double_double_double_int.cu
src/distance/specializations/detail/l2_expanded_float_float_float_int.cu
src/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_expanded_double_double_double_int.cu
src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/russel_rao_double_double_double_int.cu
src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
src/distance/specializations/detail/russel_rao_float_float_float_int.cu
src/distance/specializations/fused_l2_nn_double_int.cu
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/nn/specializations/detail/ivfpq_build.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_search.cu
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
src/nn/specializations/refine.cu
src/random/specializations/rmat_rectangular_generator_int_double.cu
src/random/specializations/rmat_rectangular_generator_int64_double.cu
src/random/specializations/rmat_rectangular_generator_int_float.cu
src/random/specializations/rmat_rectangular_generator_int64_float.cu
src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu
src/distance/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu
src/distance/distance/specializations/detail/l1_float_float_float_int.cu
src/distance/distance/specializations/detail/l1_float_float_float_uint32.cu
src/distance/distance/specializations/detail/l1_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu
src/distance/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu
src/distance/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu
src/distance/distance/specializations/fused_l2_nn_double_int.cu
src/distance/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/distance/specializations/fused_l2_nn_float_int.cu
src/distance/distance/specializations/fused_l2_nn_float_int64.cu
src/distance/neighbors/ivfpq_build.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_fast.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_fast.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
src/distance/neighbors/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
src/distance/neighbors/specializations/detail/ivfpq_search_float_int64_t.cu
src/distance/neighbors/specializations/detail/ivfpq_search_float_uint64_t.cu
src/distance/neighbors/specializations/detail/ivfpq_search_float_uint32_t.cu
src/distance/random/rmat_rectangular_generator_int_double.cu
src/distance/random/rmat_rectangular_generator_int64_double.cu
src/distance/random/rmat_rectangular_generator_int_float.cu
src/distance/random/rmat_rectangular_generator_int64_float.cu
)
set_target_properties(
raft_distance_lib
Expand Down Expand Up @@ -410,23 +412,6 @@ if(RAFT_COMPILE_NN_LIBRARY)
src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu
src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_build.cu
src/nn/specializations/detail/ivfpq_search.cu
src/nn/specializations/detail/ivfpq_search_float_int64_t.cu
src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
src/nn/specializations/fused_l2_knn_long_float_true.cu
src/nn/specializations/fused_l2_knn_long_float_false.cu
src/nn/specializations/fused_l2_knn_int_float_true.cu
Expand Down Expand Up @@ -519,7 +504,7 @@ if(TARGET raft_distance_lib)
EXPORT raft-distance-lib-exports
)
install(
DIRECTORY include/raft_distance
DIRECTORY include/raft_runtime
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
COMPONENT distance
)
Expand Down
143 changes: 85 additions & 58 deletions cpp/include/raft/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@

#include <raft/core/device_mdspan.hpp>

namespace raft {
namespace distance {

/**
* @defgroup pairwise_distance pairwise distance prims
* @defgroup pairwise_distance pointer-based pairwise distance prims
* @{
*/

namespace raft {
namespace distance {

/**
* @brief Evaluate pairwise distances with the user epilogue lamba allowed
* @tparam DistanceType which distance to evaluate
Expand Down Expand Up @@ -219,58 +219,6 @@ void distance(const InType* x,
x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg);
}

/**
* @brief Evaluate pairwise distances for the simple use case.
*
* Note: Only contiguous row- or column-major layouts supported currently.
*
* @tparam DistanceType which distance to evaluate
* @tparam InType input argument type
* @tparam AccType accumulation type
* @tparam OutType output type
* @tparam Index_ Index type
* @param handle raft handle for managing expensive resources
* @param x first set of points (size n*k)
* @param y second set of points (size m*k)
* @param dist output distance matrix (size n*m)
* @param metric_arg metric argument (used for Minkowski distance)
*/
template <raft::distance::DistanceType distanceType,
typename InType,
typename AccType,
typename OutType,
typename layout = raft::layout_c_contiguous,
typename Index_ = int>
void distance(raft::handle_t const& handle,
raft::device_matrix_view<InType, Index_, layout> const x,
raft::device_matrix_view<InType, Index_, layout> const y,
raft::device_matrix_view<OutType, Index_, layout> dist,
InType metric_arg = 2.0f)
{
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
RAFT_EXPECTS(dist.extent(0) == x.extent(0),
"Number of rows in output must be equal to "
"number of rows in X");
RAFT_EXPECTS(dist.extent(1) == y.extent(0),
"Number of columns in output must be equal to "
"number of rows in Y");

RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");

constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;

distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
y.data_handle(),
dist.data_handle(),
x.extent(0),
y.extent(0),
x.extent(1),
handle.get_stream(),
is_rowmajor,
metric_arg);
}

/**
* @brief Convenience wrapper around 'distance' prim to convert runtime metric
* into compile time for the purpose of dispatch
Expand Down Expand Up @@ -401,6 +349,85 @@ void pairwise_distance(const raft::handle_t& handle,
handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg);
}

/** @} */

/**
* \defgroup distance_mdspan Pairwise distance functions
* @{
*/

/**
* @brief Evaluate pairwise distances for the simple use case.
*
* Note: Only contiguous row- or column-major layouts supported currently.
*
* Usage example:
* @code{.cpp}
* #include <raft/core/handle.hpp>
* #include <raft/core/device_mdarray.hpp>
* #include <raft/random/make_blobs.cuh>
* #include <raft/distance/distance.cuh>
*
* raft::handle_t handle;
* int n_samples = 5000;
* int n_features = 50;
*
* auto input = raft::make_device_matrix<float>(handle, n_samples, n_features);
* auto labels = raft::make_device_vector<int>(handle, n_samples);
* auto output = raft::make_device_matrix<float>(handle, n_samples, n_samples);
*
* raft::random::make_blobs(handle, input.view(), labels.view());
* auto metric = raft::distance::DistanceType::L2SqrtExpanded;
* raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);
* @endcode
*
* @tparam DistanceType which distance to evaluate
* @tparam InType input argument type
* @tparam AccType accumulation type
* @tparam OutType output type
* @tparam Index_ Index type
* @param handle raft handle for managing expensive resources
* @param x first set of points (size n*k)
* @param y second set of points (size m*k)
* @param dist output distance matrix (size n*m)
* @param metric_arg metric argument (used for Minkowski distance)
*/
template <raft::distance::DistanceType distanceType,
typename InType,
typename AccType,
typename OutType,
typename layout = raft::layout_c_contiguous,
typename Index_ = int>
void distance(raft::handle_t const& handle,
raft::device_matrix_view<InType, Index_, layout> const x,
raft::device_matrix_view<InType, Index_, layout> const y,
raft::device_matrix_view<OutType, Index_, layout> dist,
InType metric_arg = 2.0f)
{
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
RAFT_EXPECTS(dist.extent(0) == x.extent(0),
"Number of rows in output must be equal to "
"number of rows in X");
RAFT_EXPECTS(dist.extent(1) == y.extent(0),
"Number of columns in output must be equal to "
"number of rows in Y");

RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");

constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;

distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
y.data_handle(),
dist.data_handle(),
x.extent(0),
y.extent(0),
x.extent(1),
handle.get_stream(),
is_rowmajor,
metric_arg);
}

/**
* @brief Convenience wrapper around 'distance' prim to convert runtime metric
* into compile time for the purpose of dispatch
Expand Down Expand Up @@ -449,9 +476,9 @@ void pairwise_distance(raft::handle_t const& handle,
metric_arg);
}

/** @} */

}; // namespace distance
}; // namespace raft

/** @} */

#endif
Loading

0 comments on commit 092c515

Please sign in to comment.