Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paying down some tech debt on docs, runtime API, and cython #1055

Merged
merged 39 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a7822ff
Starting to include specific function overloads
cjnolet Nov 18, 2022
dbbeb73
Fixing compile error
cjnolet Nov 18, 2022
40742ac
Merge remote-tracking branch 'rapidsai/branch-22.12' into doc-2212-re…
cjnolet Nov 18, 2022
09bdf84
Fixing up docs for pylibraft
cjnolet Nov 18, 2022
c52c4c3
Fixing compile error
cjnolet Nov 18, 2022
a045ece
More updates. Adding headers for many of the files
cjnolet Nov 23, 2022
7260bb5
Follow-on fixes
cjnolet Nov 30, 2022
247586c
Merge branch 'branch-23.02' into doc-2212-remove_broken_docs
cjnolet Nov 30, 2022
742cce8
Separating some of the distance APIs into groups
cjnolet Nov 30, 2022
9eec82e
Fixes
cjnolet Nov 30, 2022
ba08b57
Fixing pydoc for pylibraft code examples
cjnolet Dec 1, 2022
f511446
Adding more grouping
cjnolet Dec 1, 2022
6756d8b
Creating and using doxygen groups for matrix and linalg. Stats to come
cjnolet Dec 1, 2022
b8d262a
Merge branch 'branch-22.12' into doc-2212-remove_broken_docs
cjnolet Dec 1, 2022
56cf720
Merge branch 'branch-23.02' into doc-2212-remove_broken_docs
cjnolet Dec 1, 2022
802e7fa
Separating stats into doxygen groups
cjnolet Dec 2, 2022
e70bdc5
Fixing some broken doxygen groupds
cjnolet Dec 2, 2022
3a2e72c
Fix v measure
cjnolet Dec 2, 2022
47fe901
Breaking down some of the categories to make docs easier to consume
cjnolet Dec 2, 2022
1962cd9
Removing random state from datagen category
cjnolet Dec 2, 2022
7766a0b
Another fix
cjnolet Dec 2, 2022
59d4583
Consolidating mdspan cython defitions
cjnolet Dec 2, 2022
6899197
Removing unneeded factory function added during troubleshooting
cjnolet Dec 2, 2022
4a315a8
Removing more unused stuff
cjnolet Dec 2, 2022
da8ad8d
Removing typedefs
cjnolet Dec 2, 2022
16b7b9e
Update cpp/include/raft/stats/kl_divergence.cuh
cjnolet Dec 2, 2022
74efe0f
Update docs/source/cpp_api/mdspan_mdspan.rst
cjnolet Dec 2, 2022
ef4e1c7
Moving a bunch of stuff around
cjnolet Dec 2, 2022
794c157
Removing things from nn cmake
cjnolet Dec 2, 2022
c9d1b7b
Merge branch 'imp-2302-consolidate_mdspan_pxd' into doc-2212-remove_b…
cjnolet Dec 2, 2022
2ad0216
Fixing test syntax
cjnolet Dec 2, 2022
7a8689b
Merge branch 'doc-2212-remove_broken_docs' of github.com:cjnolet/raft…
cjnolet Dec 2, 2022
5f95808
Fixing build
cjnolet Dec 2, 2022
bcb57af
Fixing build
cjnolet Dec 2, 2022
a6feb71
Fixing bad merge
cjnolet Dec 2, 2022
d253981
Fixing build
cjnolet Dec 2, 2022
b325274
Merge branch 'imp-2302-consolidate_mdspan_pxd' into doc-2212-remove_b…
cjnolet Dec 2, 2022
64e2ab5
Rng state
cjnolet Dec 3, 2022
02100b1
Fixing another pxd error
cjnolet Dec 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 85 additions & 58 deletions cpp/include/raft/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@

#include <raft/core/device_mdspan.hpp>

namespace raft {
namespace distance {

/**
* @defgroup pairwise_distance pairwise distance prims
* @defgroup pairwise_distance pointer-based pairwise distance prims
* @{
*/

namespace raft {
namespace distance {

/**
* @brief Evaluate pairwise distances with the user epilogue lamba allowed
* @tparam DistanceType which distance to evaluate
Expand Down Expand Up @@ -219,58 +219,6 @@ void distance(const InType* x,
x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg);
}

/**
* @brief Evaluate pairwise distances for the simple use case.
*
* Note: Only contiguous row- or column-major layouts supported currently.
*
* @tparam DistanceType which distance to evaluate
* @tparam InType input argument type
* @tparam AccType accumulation type
* @tparam OutType output type
* @tparam Index_ Index type
* @param handle raft handle for managing expensive resources
* @param x first set of points (size n*k)
* @param y second set of points (size m*k)
* @param dist output distance matrix (size n*m)
* @param metric_arg metric argument (used for Minkowski distance)
*/
template <raft::distance::DistanceType distanceType,
typename InType,
typename AccType,
typename OutType,
typename layout = raft::layout_c_contiguous,
typename Index_ = int>
void distance(raft::handle_t const& handle,
raft::device_matrix_view<InType, Index_, layout> const x,
raft::device_matrix_view<InType, Index_, layout> const y,
raft::device_matrix_view<OutType, Index_, layout> dist,
InType metric_arg = 2.0f)
{
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
RAFT_EXPECTS(dist.extent(0) == x.extent(0),
"Number of rows in output must be equal to "
"number of rows in X");
RAFT_EXPECTS(dist.extent(1) == y.extent(0),
"Number of columns in output must be equal to "
"number of rows in Y");

RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");

constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;

distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
y.data_handle(),
dist.data_handle(),
x.extent(0),
y.extent(0),
x.extent(1),
handle.get_stream(),
is_rowmajor,
metric_arg);
}

/**
* @brief Convenience wrapper around 'distance' prim to convert runtime metric
* into compile time for the purpose of dispatch
Expand Down Expand Up @@ -401,6 +349,85 @@ void pairwise_distance(const raft::handle_t& handle,
handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg);
}

/** @} */

/**
* \defgroup distance_mdspan Pairwise distance functions
* @{
*/

/**
* @brief Evaluate pairwise distances for the simple use case.
*
* Note: Only contiguous row- or column-major layouts supported currently.
*
* Usage example:
* @code{.cpp}
* #include <raft/core/handle.hpp>
* #include <raft/core/device_mdarray.hpp>
* #include <raft/random/make_blobs.cuh>
* #include <raft/distance/distance.cuh>
*
* raft::handle_t handle;
* int n_samples = 5000;
* int n_features = 50;
*
* auto input = raft::make_device_matrix<float>(handle, n_samples, n_features);
* auto labels = raft::make_device_vector<int>(handle, n_samples);
* auto output = raft::make_device_matrix<float>(handle, n_samples, n_samples);
*
* raft::random::make_blobs(handle, input.view(), labels.view());
* auto metric = raft::distance::DistanceType::L2SqrtExpanded;
* raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);
* @endcode
*
* @tparam DistanceType which distance to evaluate
* @tparam InType input argument type
* @tparam AccType accumulation type
* @tparam OutType output type
* @tparam Index_ Index type
* @param handle raft handle for managing expensive resources
* @param x first set of points (size n*k)
* @param y second set of points (size m*k)
* @param dist output distance matrix (size n*m)
* @param metric_arg metric argument (used for Minkowski distance)
*/
template <raft::distance::DistanceType distanceType,
typename InType,
typename AccType,
typename OutType,
typename layout = raft::layout_c_contiguous,
typename Index_ = int>
void distance(raft::handle_t const& handle,
raft::device_matrix_view<InType, Index_, layout> const x,
raft::device_matrix_view<InType, Index_, layout> const y,
raft::device_matrix_view<OutType, Index_, layout> dist,
InType metric_arg = 2.0f)
{
RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal.");
RAFT_EXPECTS(dist.extent(0) == x.extent(0),
"Number of rows in output must be equal to "
"number of rows in X");
RAFT_EXPECTS(dist.extent(1) == y.extent(0),
"Number of columns in output must be equal to "
"number of rows in Y");

RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous.");
RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous.");

constexpr auto is_rowmajor = std::is_same_v<layout, layout_c_contiguous>;

distance<distanceType, InType, AccType, OutType, Index_>(x.data_handle(),
y.data_handle(),
dist.data_handle(),
x.extent(0),
y.extent(0),
x.extent(1),
handle.get_stream(),
is_rowmajor,
metric_arg);
}

/**
* @brief Convenience wrapper around 'distance' prim to convert runtime metric
* into compile time for the purpose of dispatch
Expand Down Expand Up @@ -449,9 +476,9 @@ void pairwise_distance(raft::handle_t const& handle,
metric_arg);
}

/** @} */

}; // namespace distance
}; // namespace raft

/** @} */

#endif
12 changes: 12 additions & 0 deletions cpp/include/raft/distance/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@

namespace raft {
namespace distance {
/**
* \defgroup fused_l2_nn Fused 1-nearest neighbors
* @{
*/

template <typename LabelT, typename DataT>
using KVPMinReduce = detail::KVPMinReduceImpl<LabelT, DataT>;
Expand All @@ -40,6 +44,8 @@ using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl<LabelT, DataT>
template <typename LabelT, typename DataT>
using MinReduceOp = detail::MinReduceOpImpl<LabelT, DataT>;

/** @} */

/**
* Initialize array using init value from reduction op
*/
Expand All @@ -49,6 +55,10 @@ void initialize(const raft::handle_t& handle, OutT* min, IdxT m, DataT maxVal, R
detail::initialize<DataT, OutT, IdxT, ReduceOpT>(min, m, maxVal, redOp, handle.get_stream());
}

/**
* \ingroup fused_l2_nn
* @{
*/
/**
* @brief Fused L2 distance and 1-nearest-neighbor computation in a single call.
*
Expand Down Expand Up @@ -211,6 +221,8 @@ void fusedL2NNMinReduce(OutT* min,
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}

/** @} */

} // namespace distance
} // namespace raft

Expand Down
9 changes: 1 addition & 8 deletions cpp/include/raft/linalg/add.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
#ifndef __ADD_H
#define __ADD_H

/**
* @defgroup arithmetic Dense matrix arithmetic
* @{
*/

#pragma once

#include "detail/add.cuh"
Expand Down Expand Up @@ -94,7 +89,7 @@ void addDevScalar(
}

/**
* @defgroup add Addition Arithmetic
* @defgroup add_dense Addition Arithmetic
* @{
*/

Expand Down Expand Up @@ -226,6 +221,4 @@ void add_scalar(const raft::handle_t& handle,
}; // end namespace linalg
}; // end namespace raft

/** @} */

#endif
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/axpy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void axpy(const raft::handle_t& handle,
}

/**
* @defgroup axpy axpy
* @defgroup axpy axpy routine
* @{
*/

Expand Down
9 changes: 9 additions & 0 deletions cpp/include/raft/linalg/dot.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {

/**
* @defgroup dot BLAS dot routine
* @{
*/

/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
Expand Down Expand Up @@ -84,5 +90,8 @@ void dot(const raft::handle_t& handle,
out.data_handle(),
handle.get_stream()));
}

/** @} */ // end of group dot

} // namespace raft::linalg
#endif
10 changes: 5 additions & 5 deletions cpp/include/raft/linalg/eig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@
namespace raft {
namespace linalg {

/**
* @defgroup eig Eigen Decomposition Methods
* @{
*/

/**
* @brief eig decomp with divide and conquer method for the column-major
* symmetric matrices
Expand Down Expand Up @@ -115,6 +110,11 @@ void eigJacobi(const raft::handle_t& handle,
detail::eigJacobi(handle, in, n_rows, n_cols, eig_vectors, eig_vals, stream, tol, sweeps);
}

/**
* @defgroup eig Eigen Decomposition Methods
* @{
*/

/**
* @brief eig decomp with divide and conquer method for the column-major
* symmetric matrices
Expand Down
9 changes: 4 additions & 5 deletions cpp/include/raft/linalg/map_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@

namespace raft::linalg {

/**
* @defgroup map_reduce Map-Reduce ops
* @{
*/

/**
* @brief CUDA version of map and then generic reduction operation
* @tparam Type data-type upon which the math operation will be performed
Expand Down Expand Up @@ -67,6 +62,10 @@ void mapReduce(OutType* out,
out, len, neutral, map, op, stream, in, args...);
}

/**
* @defgroup map_reduce Map-Reduce ops
* @{
*/
/**
* @brief CUDA version of map and then generic reduction operation
* @tparam InValueType the data-type of the input
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/linalg/matrix_vector.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@

namespace raft::linalg {

/**
* @defgroup matrix_vector Matrix-Vector Operations
* @{
*/

/**
* @brief multiply each row or column of matrix with vector, skipping zeros in vector
* @param [in] handle: raft handle for managing library resources
Expand Down Expand Up @@ -191,4 +196,7 @@ void binary_sub(const raft::handle_t& handle,
bcast_along_rows,
handle.get_stream());
}

/** @} */ // end of matrix_vector

} // namespace raft::linalg
7 changes: 7 additions & 0 deletions cpp/include/raft/linalg/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ void colNorm(Type* dots,
detail::colNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op);
}

/**
* @defgroup norm Row- or Col-norm computation
* @{
*/

/**
* @brief Compute norm of the input matrix and perform fin_op
* @tparam ElementType Input/Output data type
Expand Down Expand Up @@ -142,6 +147,8 @@ void norm(const raft::handle_t& handle,
}
}

/** @} */

}; // end namespace linalg
}; // end namespace raft

Expand Down
7 changes: 7 additions & 0 deletions cpp/include/raft/linalg/normalize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
namespace raft {
namespace linalg {

/**
* @defgroup norm Row- or Col-norm computation
* @{
*/

/**
* @brief Divide rows by their norm defined by main_op, reduce_op and fin_op
*
Expand Down Expand Up @@ -127,5 +132,7 @@ void row_normalize(const raft::handle_t& handle,
}
}

/** @} */

} // namespace linalg
} // namespace raft
Loading