diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 4d6c1836fc..2fc43e2a05 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -44,18 +44,23 @@ using managed_mdspan = mdspan -struct is_device_accessible_mdspan : std::false_type { +struct is_device_mdspan : std::false_type { }; template -struct is_device_accessible_mdspan - : std::bool_constant { +struct is_device_mdspan : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type */ template -using is_device_accessible_mdspan_t = is_device_accessible_mdspan>; +using is_device_mdspan_t = is_device_mdspan>; + +template +using is_input_device_mdspan_t = is_device_mdspan>; + +template +using is_output_device_mdspan_t = is_device_mdspan>; template struct is_managed_mdspan : std::false_type { @@ -70,6 +75,12 @@ struct is_managed_mdspan : std::bool_constant using is_managed_mdspan_t = is_managed_mdspan>; +template +using is_input_managed_mdspan_t = is_managed_mdspan>; + +template +using is_output_managed_mdspan_t = is_managed_mdspan>; + } // end namespace detail /** @@ -77,11 +88,24 @@ using is_managed_mdspan_t = is_managed_mdspan>; * derived type */ template -inline constexpr bool is_device_accessible_mdspan_v = - std::conjunction_v...>; +inline constexpr bool is_device_mdspan_v = std::conjunction_v...>; + +template +inline constexpr bool is_input_device_mdspan_v = + std::conjunction_v...>; + +template +inline constexpr bool is_output_device_mdspan_v = + std::conjunction_v...>; template -using enable_if_device_mdspan = std::enable_if_t>; +using enable_if_device_mdspan = std::enable_if_t>; + +template +using enable_if_input_device_mdspan = std::enable_if_t>; + +template +using enable_if_output_device_mdspan = std::enable_if_t>; /** * @\brief Boolean to determine if variadic template types Tn are either raft::managed_mdspan or a @@ -90,9 +114,23 @@ using enable_if_device_mdspan = std::enable_if_t inline constexpr bool is_managed_mdspan_v = std::conjunction_v...>; +template +inline constexpr bool is_input_managed_mdspan_v = + std::conjunction_v...>; + +template +inline constexpr bool is_output_managed_mdspan_v = + std::conjunction_v...>; + template using enable_if_managed_mdspan = std::enable_if_t>; +template +using enable_if_input_managed_mdspan = std::enable_if_t>; + +template +using enable_if_output_managed_mdspan = std::enable_if_t>; + /** * @brief Shorthand for 0-dim host mdspan (scalar). 
* @tparam ElementType the data type of the scalar element diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index e6ab22004e..fc2a9bbd6d 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -37,18 +37,23 @@ using host_mdspan = mdspan -struct is_host_accessible_mdspan : std::false_type { +struct is_host_mdspan : std::false_type { }; template -struct is_host_accessible_mdspan - : std::bool_constant { +struct is_host_mdspan : std::bool_constant { }; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type */ template -using is_host_accessible_mdspan_t = is_host_accessible_mdspan>; +using is_host_mdspan_t = is_host_mdspan>; + +template +using is_input_host_mdspan_t = is_host_mdspan>; + +template +using is_output_host_mdspan_t = is_host_mdspan>; } // namespace detail @@ -57,11 +62,24 @@ using is_host_accessible_mdspan_t = is_host_accessible_mdspan> * derived type */ template -inline constexpr bool is_host_accessible_mdspan_v = - std::conjunction_v...>; +inline constexpr bool is_host_mdspan_v = std::conjunction_v...>; + +template +inline constexpr bool is_input_host_mdspan_v = + std::conjunction_v...>; + +template +inline constexpr bool is_output_host_mdspan_v = + std::conjunction_v...>; + +template +using enable_if_host_mdspan = std::enable_if_t>; + +template +using enable_if_input_host_mdspan = std::enable_if_t>; template -using enable_if_host_mdspan = std::enable_if_t>; +using enable_if_output_host_mdspan = std::enable_if_t>; /** * @brief Shorthand for 0-dim host mdspan (scalar). diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 7169a010b6..6281ca98ea 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -57,9 +57,31 @@ struct is_mdspan( : std::true_type { }; +template +struct is_input_mdspan : std::false_type { +}; +template +struct is_input_mdspan()))>> + : std::bool_constant> { +}; + +template +struct is_output_mdspan : std::false_type { +}; +template +struct is_output_mdspan()))>> + : std::bool_constant> { +}; + template using is_mdspan_t = is_mdspan>; +template +using is_input_mdspan_t = is_input_mdspan; + +template +using is_output_mdspan_t = is_output_mdspan; + /** * @\brief Boolean to determine if variadic template types Tn are either * raft::host_mdspan/raft::device_mdspan or their derived types @@ -70,6 +92,18 @@ inline constexpr bool is_mdspan_v = std::conjunction_v...>; template using enable_if_mdspan = std::enable_if_t>; +template +inline constexpr bool is_input_mdspan_v = std::conjunction_v...>; + +template +using enable_if_input_mdspan = std::enable_if_t>; + +template +inline constexpr bool is_output_mdspan_v = std::conjunction_v...>; + +template +using enable_if_output_mdspan = std::enable_if_t>; + // uint division optimization inspired by the CIndexer in cupy. Division operation is // slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64 // bit when the index is smaller, then try to avoid division when it's exp of 2. 
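// --- Illustrative sketch (not part of the patch) ---------------------------------------------
// A minimal example, assuming the raft 22.x core header layout, of how the renamed trait
// (is_device_mdspan_v, formerly is_device_accessible_mdspan_v) and the new input/output
// enable_if aliases are meant to constrain an mdspan-taking function template, in the same style
// the linalg wrappers later in this diff use them. `example_op` and its body are hypothetical.
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>

template <typename InType,
          typename OutType,
          typename = raft::enable_if_input_device_mdspan<InType>,
          typename = raft::enable_if_output_device_mdspan<OutType>>
void example_op(const raft::handle_t& handle, InType in, OutType out)
{
  // The overload participates only when InType is a device mdspan (typically over const elements)
  // and OutType is a device mdspan over mutable elements; both therefore also satisfy the renamed
  // variadic trait below.
  static_assert(raft::is_device_mdspan_v<InType, OutType>, "constrained to device mdspans");
  auto stream = handle.get_stream();
  // ... enqueue device work reading in.data_handle() and writing out.data_handle() on `stream` ...
  (void)stream;
}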
diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index e25c9df9ef..9f1d5d4a33 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -25,6 +25,10 @@ #include "detail/add.cuh" +#include +#include +#include + namespace raft { namespace linalg { @@ -46,7 +50,7 @@ using detail::adds_scalar; * @param stream cuda stream where to launch work */ template -void addScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) +void addScalar(OutT* out, const InT* in, const InT scalar, IdxType len, cudaStream_t stream) { detail::addScalar(out, in, scalar, len, stream); } @@ -72,7 +76,9 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st /** Substract single value pointed by singleScalarDev parameter in device memory from inDev[i] and * write result to outDev[i] - * @tparam math_t data-type upon which the math operation will be performed + * @tparam InT input data-type. Also the data-type upon which the math ops + * will be performed + * @tparam OutT output data-type * @tparam IdxType Integer type used to for addressing * @param outDev the output buffer * @param inDev the input buffer @@ -80,16 +86,143 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st * @param len number of elements in the input and output buffer * @param stream cuda stream */ -template -void addDevScalar(math_t* outDev, - const math_t* inDev, - const math_t* singleScalarDev, - IdxType len, - cudaStream_t stream) +template +void addDevScalar( + OutT* outDev, const InT* inDev, const InT* singleScalarDev, IdxType len, cudaStream_t stream) { detail::addDevScalar(outDev, inDev, singleScalarDev, len, stream); } +/** + * @defgroup add Addition Arithmetic + * @{ + */ + +/** + * @brief Elementwise add operation + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @param[in] handle raft::handle_t + * @param[in] in1 First Input + * @param[in] in2 Second Input + * @param[out] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void add(const raft::handle_t& handle, InType in1, InType in2, OutType out) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous"); + RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(), + "Size mismatch between Output and Inputs"); + + if (out.size() <= std::numeric_limits::max()) { + add(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + add(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** + * @brief Elementwise addition of device scalar to input + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[in] scalar raft::device_scalar_view + * @param[in] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void add_scalar(const raft::handle_t& handle, + InType in, + OutType out, + raft::device_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; 
+ using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + addDevScalar(out.data_handle(), + in.data_handle(), + scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + addDevScalar(out.data_handle(), + in.data_handle(), + scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** + * @brief Elementwise addition of host scalar to input + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[in] scalar raft::host_scalar_view + * @param[in] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void add_scalar(const raft::handle_t& handle, + const InType in, + OutType out, + raft::host_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + addScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + addScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** @} */ // end of group add + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 2e23047b5a..6d54f87e91 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -50,6 +50,84 @@ void axpy(const raft::handle_t& handle, detail::axpy(handle, n, alpha, x, incx, y, incy, stream); } +/** + * @defgroup axpy axpy + * @{ + */ + +/** + * @brief axpy function + * It computes the following equation: y = alpha * x + y + * + * @tparam InType Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param [in] handle raft::handle_t + * @param [in] alpha raft::device_scalar_view + * @param [in] x Input vector + * @param [inout] y Output vector + * @param [in] incx stride between consecutive elements of x + * @param [in] incy stride between consecutive elements of y + */ +template , + typename = raft::enable_if_output_device_mdspan> +void axpy(const raft::handle_t& handle, + raft::device_scalar_view alpha, + InType x, + OutType y, + const int incx, + const int incy) +{ + RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input") + + axpy(handle, + y.size(), + alpha.data_handle(), + x.data_handle(), + incx, + y.data_handle(), + incy, + handle.get_stream()); +} + +/** + * @brief axpy function + * It computes the following equation: y = alpha * x + y + * + * @tparam MdspanType Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param [in] handle raft::handle_t + * @param [in] alpha raft::device_scalar_view + * @param [in] x Input vector + * @param [inout] y Output vector + * @param [in] incx stride between consecutive elements of x + * @param 
[in] incy stride between consecutive elements of y + */ +template , + typename = raft::enable_if_output_device_mdspan> +void axpy(const raft::handle_t& handle, + raft::host_scalar_view alpha, + InType x, + OutType y, + const int incx, + const int incy) +{ + RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input") + + axpy(handle, + y.size(), + alpha.data_handle(), + x.data_handle(), + incx, + y.data_handle(), + incy, + handle.get_stream()); +} + +/** @} */ // end of group axpy + } // namespace raft::linalg #endif \ No newline at end of file diff --git a/cpp/include/raft/linalg/binary_op.cuh b/cpp/include/raft/linalg/binary_op.cuh index c3827f79bf..693ef961c2 100644 --- a/cpp/include/raft/linalg/binary_op.cuh +++ b/cpp/include/raft/linalg/binary_op.cuh @@ -20,7 +20,10 @@ #include "detail/binary_op.cuh" +#include +#include #include +#include namespace raft { namespace linalg { @@ -52,6 +55,51 @@ void binaryOp( detail::binaryOp(out, in1, in2, len, op, stream); } +/** + * @defgroup binary_op Element-Wise Binary Operation + * @{ + */ + +/** + * @brief perform element-wise binary operation on the input arrays + * @tparam InType Input Type raft::device_mdspan + * @tparam Lambda the device-lambda performing the actual operation + * @tparam OutType Output Type raft::device_mdspan + * @param[in] handle raft::handle_t + * @param[in] in1 First input + * @param[in] in2 Second input + * @param[out] out Output + * @param[in] op the device-lambda + * @note Lambda must be a functor with the following signature: + * `OutType func(const InType& val1, const InType& val2);` + */ +template , + typename = raft::enable_if_output_device_mdspan> +void binary_op(const raft::handle_t& handle, InType in1, InType in2, OutType out, Lambda op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous"); + RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(), + "Size mismatch between Output and Inputs"); + + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + if (out.size() <= std::numeric_limits::max()) { + binaryOp( + out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream()); + } else { + binaryOp( + out.data_handle(), in1.data_handle(), in2.data_handle(), out.size(), op, handle.get_stream()); + } +} + +/** @} */ // end of group binary_op + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/cholesky_r1_update.cuh b/cpp/include/raft/linalg/cholesky_r1_update.cuh index d8e838a634..f40866b235 100644 --- a/cpp/include/raft/linalg/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/cholesky_r1_update.cuh @@ -25,6 +25,7 @@ namespace linalg { /** * @brief Rank 1 update of Cholesky decomposition. + * NOTE: The new mdspan-based API will not be provided for this function. * * This method is useful if an algorithm iteratively builds up matrix A, and * the Cholesky decomposition of A is required at each step. @@ -109,7 +110,7 @@ namespace linalg { * @param L device array for to store the triangular matrix L, and the new * column of A in column major format, size [n*n] * @param n number of elements in the new row. - * @param ld stride of colums in L + * @param ld stride of columns in L * @param workspace device pointer to workspace shall be nullptr ar an array * of size [n_bytes]. 
* @param n_bytes size of workspace is returned here if workspace==nullptr. @@ -132,6 +133,7 @@ void choleskyRank1Update(const raft::handle_t& handle, { detail::choleskyRank1Update(handle, L, n, ld, workspace, n_bytes, uplo, stream, eps); } + }; // namespace linalg }; // namespace raft diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 03477f72d6..518667c5f1 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -20,6 +20,9 @@ #include "detail/coalesced_reduction.cuh" +#include +#include + namespace raft { namespace linalg { @@ -58,8 +61,8 @@ template > void coalescedReduction(OutType* dots, const InType* data, - int D, - int N, + IdxType D, + IdxType N, OutType init, cudaStream_t stream, bool inplace = false, @@ -70,6 +73,90 @@ void coalescedReduction(OutType* dots, detail::coalescedReduction(dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } +/** + * @defgroup coalesced_reduction Coalesced Memory Access Reductions + * For reducing along rows for col-major and along columns for row-major + * @{ + */ + +/** + * @brief Compute reduction of the input matrix along the leading dimension + * This API is to be used when the desired reduction is along the dimension + * of the memory layout. For example, a row-major matrix will be reduced + * along the columns whereas a column-major matrix will be reduced along + * the rows. + * + * @tparam InValueType the input data-type of underlying raft::matrix_view + * @tparam LayoutPolicy The layout of Input/Output (row or col major) + * @tparam OutValueType the output data-type of underlying raft::matrix_view and reduction + * @tparam IndexType Integer type used for addressing + * @tparam MainLambda Unary lambda applied during accumulation (eg: L1 or L2 norm) + * It must be a 'callable' supporting the following input and output: + * <pre>OutType (*MainLambda)(InType, IdxType);</pre>
+ * @tparam ReduceLambda Binary lambda applied for reduction (eg: addition(+) for L2 norm) + * It must be a 'callable' supporting the following input and output: + * <pre>OutType (*ReduceLambda)(OutType, OutType);</pre>
+ * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) + * It must be a 'callable' supporting the following input and output: + * <pre>OutType (*FinalLambda)(OutType);</pre>
+ * @param handle raft::handle_t + * @param[in] data Input of type raft::device_matrix_view + * @param[out] dots Output of type raft::device_matrix_view + * @param[in] init initial value to use for the reduction + * @param[in] inplace reduction result added inplace or overwrites old values? + * @param[in] main_op fused elementwise operation to apply before reduction + * @param[in] reduce_op fused binary reduction operation + * @param[in] final_op fused elementwise operation to apply before storing results + */ +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void coalesced_reduction(const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view dots, + OutValueType init, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + if constexpr (std::is_same_v) { + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), + "Output should be equal to number of rows in Input"); + + coalescedReduction(dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + handle.get_stream(), + inplace, + main_op, + reduce_op, + final_op); + } else if constexpr (std::is_same_v) { + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), + "Output should be equal to number of columns in Input"); + + coalescedReduction(dots.data_handle(), + data.data_handle(), + data.extent(0), + data.extent(1), + init, + handle.get_stream(), + inplace, + main_op, + reduce_op, + final_op); + } +} + +/** @} */ // end of group coalesced_reduction + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/detail/add.cuh b/cpp/include/raft/linalg/detail/add.cuh index 3cd583faa5..34966ebbc2 100644 --- a/cpp/include/raft/linalg/detail/add.cuh +++ b/cpp/include/raft/linalg/detail/add.cuh @@ -40,27 +40,24 @@ void add(OutT* out, const InT* in1, const InT* in2, IdxType len, cudaStream_t st raft::linalg::binaryOp(out, in1, in2, len, thrust::plus(), stream); } -template -__global__ void add_dev_scalar_kernel(math_t* outDev, - const math_t* inDev, - const math_t* singleScalarDev, +template +__global__ void add_dev_scalar_kernel(OutT* outDev, + const InT* inDev, + const InT* singleScalarDev, IdxType len) { IdxType i = ((IdxType)blockIdx.x * (IdxType)blockDim.x) + threadIdx.x; if (i < len) { outDev[i] = inDev[i] + *singleScalarDev; } } -template -void addDevScalar(math_t* outDev, - const math_t* inDev, - const math_t* singleScalarDev, - IdxType len, - cudaStream_t stream) +template +void addDevScalar( + OutT* outDev, const InT* inDev, const InT* singleScalarDev, IdxType len, cudaStream_t stream) { // TODO: block dimension has not been tuned dim3 block(256); dim3 grid(raft::ceildiv(len, (IdxType)block.x)); - add_dev_scalar_kernel<<>>(outDev, inDev, singleScalarDev, len); + add_dev_scalar_kernel<<>>(outDev, inDev, singleScalarDev, len); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/linalg/detail/divide.cuh b/cpp/include/raft/linalg/detail/divide.cuh index cb46ae76de..333cd3e83c 100644 --- a/cpp/include/raft/linalg/detail/divide.cuh +++ b/cpp/include/raft/linalg/detail/divide.cuh @@ -17,16 +17,18 @@ #pragma once #include "functional.cuh" + +#include #include namespace raft { namespace linalg { namespace detail { -template -void divideScalar(math_t* out, const math_t* in, math_t scalar, IdxType len, cudaStream_t stream) +template +void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t 
stream) { - raft::linalg::unaryOp(out, in, len, divides_scalar(scalar), stream); + raft::linalg::unaryOp(out, in, len, divides_scalar(scalar), stream); } }; // end namespace detail diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index dfd6bd4f7c..d48b42fc57 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -139,9 +139,9 @@ enum EigVecMemUsage { OVERWRITE_INPUT, COPY_INPUT }; template void eigSelDC(const raft::handle_t& handle, math_t* in, - int n_rows, - int n_cols, - int n_eig_vals, + std::size_t n_rows, + std::size_t n_cols, + std::size_t n_eig_vals, math_t* eig_vectors, math_t* eig_vals, EigVecMemUsage memUsage, @@ -156,13 +156,13 @@ void eigSelDC(const raft::handle_t& handle, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, - n_rows, + static_cast(n_rows), in, - n_cols, + static_cast(n_cols), math_t(0.0), math_t(0.0), - n_cols - n_eig_vals + 1, - n_cols, + static_cast(n_cols - n_eig_vals + 1), + static_cast(n_cols), &h_meig, eig_vals, &lwork)); @@ -176,13 +176,13 @@ void eigSelDC(const raft::handle_t& handle, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, - n_rows, + static_cast(n_rows), in, - n_cols, + static_cast(n_cols), math_t(0.0), math_t(0.0), - n_cols - n_eig_vals + 1, - n_cols, + static_cast(n_cols - n_eig_vals + 1), + static_cast(n_cols), &h_meig, eig_vals, d_work.data(), @@ -197,13 +197,13 @@ void eigSelDC(const raft::handle_t& handle, CUSOLVER_EIG_MODE_VECTOR, CUSOLVER_EIG_RANGE_I, CUBLAS_FILL_MODE_UPPER, - n_rows, + static_cast(n_rows), eig_vectors, - n_cols, + static_cast(n_cols), math_t(0.0), math_t(0.0), - n_cols - n_eig_vals + 1, - n_cols, + static_cast(n_cols - n_eig_vals + 1), + static_cast(n_cols), &h_meig, eig_vals, d_work.data(), @@ -230,8 +230,8 @@ void eigSelDC(const raft::handle_t& handle, template void eigJacobi(const raft::handle_t& handle, const math_t* in, - int n_rows, - int n_cols, + std::size_t n_rows, + std::size_t n_cols, math_t* eig_vectors, math_t* eig_vals, cudaStream_t stream, @@ -249,9 +249,9 @@ void eigJacobi(const raft::handle_t& handle, RAFT_CUSOLVER_TRY(cusolverDnsyevj_bufferSize(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, - n_rows, + static_cast(n_rows), eig_vectors, - n_cols, + static_cast(n_cols), eig_vals, &lwork, syevj_params)); @@ -264,9 +264,9 @@ void eigJacobi(const raft::handle_t& handle, RAFT_CUSOLVER_TRY(cusolverDnsyevj(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, - n_rows, + static_cast(n_rows), eig_vectors, - n_cols, + static_cast(n_cols), eig_vals, d_work.data(), lwork, diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index 5742048864..baa066984b 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -148,7 +148,7 @@ void gemm(const raft::handle_t& handle, handle, a, n_rows_a, n_cols_a, b, c, n_rows_c, n_cols_c, trans_a, trans_b, alpha, beta, stream); } -template +template void gemm(const raft::handle_t& handle, T* z, T* x, @@ -160,10 +160,11 @@ void gemm(const raft::handle_t& handle, bool isXColMajor, bool isYColMajor, cudaStream_t stream, - T alpha = T(1.0), - T beta = T(0.0)) + T* alpha, + T* beta) { cublasHandle_t cublas_h = handle.get_cublas_handle(); + cublas_device_pointer_mode pmode(cublas_h); cublasOperation_t trans_a, trans_b; T *a, *b, *c; @@ -233,7 +234,7 @@ void gemm(const raft::handle_t& handle, } // Actual cuBLAS call RAFT_CUBLAS_TRY( - 
cublasgemm(cublas_h, trans_a, trans_b, M, N, K, &alpha, a, lda, b, ldb, &beta, c, ldc, stream)); + cublasgemm(cublas_h, trans_a, trans_b, M, N, K, alpha, a, lda, b, ldb, beta, c, ldc, stream)); } } // namespace detail diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index 2c73521887..add003eb52 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -25,20 +25,30 @@ namespace raft { namespace linalg { namespace detail { -template -__global__ void mapKernel(OutType* out, size_t len, MapOp map, const InType* in, Args... args) +template +__global__ void mapKernel(OutType* out, IdxType len, MapOp map, const InType* in, Args... args) { auto idx = (threadIdx.x + (blockIdx.x * blockDim.x)); if (idx < len) { out[idx] = map(in[idx], args[idx]...); } } -template +template void mapImpl( - OutType* out, size_t len, MapOp map, cudaStream_t stream, const InType* in, Args... args) + OutType* out, IdxType len, MapOp map, cudaStream_t stream, const InType* in, Args... args) { - const int nblks = raft::ceildiv(len, (size_t)TPB); - mapKernel + const int nblks = raft::ceildiv(len, (IdxType)TPB); + mapKernel <<>>(out, len, map, in, args...); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/linalg/detail/map_then_reduce.cuh b/cpp/include/raft/linalg/detail/map_then_reduce.cuh index 9c0a21ee5c..7ef9ca1c43 100644 --- a/cpp/include/raft/linalg/detail/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/detail/map_then_reduce.cuh @@ -48,12 +48,13 @@ __device__ void reduce(OutType* out, const InType acc, ReduceLambda op) template __global__ void mapThenReduceKernel(OutType* out, - size_t len, + IdxType len, OutType neutral, MapOp map, ReduceLambda op, @@ -72,12 +73,13 @@ __global__ void mapThenReduceKernel(OutType* out, template void mapThenReduceImpl(OutType* out, - size_t len, + IdxType len, OutType neutral, MapOp map, ReduceLambda op, @@ -86,8 +88,8 @@ void mapThenReduceImpl(OutType* out, Args... 
args) { raft::update_device(out, &neutral, 1, stream); - const int nblks = raft::ceildiv(len, (size_t)TPB); - mapThenReduceKernel + const int nblks = raft::ceildiv(len, IdxType(TPB)); + mapThenReduceKernel <<>>(out, len, neutral, map, op, in, args...); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/linalg/detail/multiply.cuh b/cpp/include/raft/linalg/detail/multiply.cuh index ec3ec802de..f1a8548bfa 100644 --- a/cpp/include/raft/linalg/detail/multiply.cuh +++ b/cpp/include/raft/linalg/detail/multiply.cuh @@ -23,7 +23,8 @@ namespace linalg { namespace detail { template -void multiplyScalar(math_t* out, const math_t* in, math_t scalar, IdxType len, cudaStream_t stream) +void multiplyScalar( + math_t* out, const math_t* in, const math_t scalar, IdxType len, cudaStream_t stream) { raft::linalg::unaryOp( out, in, len, [scalar] __device__(math_t in) { return in * scalar; }, stream); diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index f64631689a..cc86716a8d 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -32,8 +32,8 @@ template > void reduce(OutType* dots, const InType* data, - int D, - int N, + IdxType D, + IdxType N, OutType init, bool rowMajor, bool alongRows, diff --git a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh index 007c05c0d4..9ddcbae20b 100644 --- a/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/detail/reduce_rows_by_key.cuh @@ -95,16 +95,16 @@ struct quadSum { template __launch_bounds__(SUM_ROWS_SMALL_K_DIMX, 4) - __global__ void sum_rows_by_key_small_nkeys_kernel(const DataIteratorT d_A, + __global__ void sum_rows_by_key_small_nkeys_kernel(const DataIteratorT* d_A, int lda, const char* d_keys, const WeightT* d_weights, int nrows, int ncols, int nkeys, - DataIteratorT d_sums) + DataIteratorT* d_sums) { - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; typedef cub::BlockReduce, SUM_ROWS_SMALL_K_DIMX> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -158,14 +158,14 @@ __launch_bounds__(SUM_ROWS_SMALL_K_DIMX, 4) } template -void sum_rows_by_key_small_nkeys(const DataIteratorT d_A, +void sum_rows_by_key_small_nkeys(const DataIteratorT* d_A, int lda, const char* d_keys, const WeightT* d_weights, int nrows, int ncols, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t st) { dim3 grid, block; @@ -189,18 +189,18 @@ void sum_rows_by_key_small_nkeys(const DataIteratorT d_A, #define SUM_ROWS_BY_KEY_LARGE_K_MAX_K 1024 template -__global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT d_A, +__global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT* d_A, int lda, - const KeysIteratorT d_keys, + KeysIteratorT d_keys, const WeightT* d_weights, int nrows, int ncols, int key_offset, int nkeys, - DataIteratorT d_sums) + DataIteratorT* d_sums) { typedef typename std::iterator_traits::value_type KeyType; - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; __shared__ DataType local_sums[SUM_ROWS_BY_KEY_LARGE_K_MAX_K]; for (int local_key = threadIdx.x; local_key < nkeys; local_key += blockDim.x) @@ -238,14 +238,14 @@ __global__ void sum_rows_by_key_large_nkeys_kernel_colmajor(const DataIteratorT } template -void 
sum_rows_by_key_large_nkeys_colmajor(const DataIteratorT d_A, +void sum_rows_by_key_large_nkeys_colmajor(const DataIteratorT* d_A, int lda, KeysIteratorT d_keys, int nrows, int ncols, int key_offset, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t st) { dim3 grid, block; @@ -264,7 +264,7 @@ void sum_rows_by_key_large_nkeys_colmajor(const DataIteratorT d_A, //#define RRBK_SHMEM template -__global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT d_A, +__global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT* d_A, int lda, const WeightT* d_weights, KeysIteratorT d_keys, @@ -272,10 +272,10 @@ __global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT int ncols, int key_offset, int nkeys, - DataIteratorT d_sums) + DataIteratorT* d_sums) { typedef typename std::iterator_traits::value_type KeyType; - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; #ifdef RRBK_SHMEM __shared__ KeyType sh_keys[RRBK_SHMEM_SZ]; @@ -320,15 +320,15 @@ __global__ void sum_rows_by_key_large_nkeys_kernel_rowmajor(const DataIteratorT } template -void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT d_A, +void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT* d_A, int lda, - const KeysIteratorT d_keys, + KeysIteratorT d_keys, const WeightT* d_weights, int nrows, int ncols, int key_offset, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t st) { // x-dim refers to the column in the input data @@ -367,19 +367,19 @@ void sum_rows_by_key_large_nkeys_rowmajor(const DataIteratorT d_A, * @param[in] stream CUDA stream */ template -void reduce_rows_by_key(const DataIteratorT d_A, +void reduce_rows_by_key(const DataIteratorT* d_A, int lda, - const KeysIteratorT d_keys, + KeysIteratorT d_keys, const WeightT* d_weights, char* d_keys_char, int nrows, int ncols, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t stream) { typedef typename std::iterator_traits::value_type KeyType; - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; // Following kernel needs memset cudaMemsetAsync(d_sums, 0, ncols * nkeys * sizeof(DataType), stream); @@ -418,17 +418,17 @@ void reduce_rows_by_key(const DataIteratorT d_A, * @param[in] stream CUDA stream */ template -void reduce_rows_by_key(const DataIteratorT d_A, +void reduce_rows_by_key(const DataIteratorT* d_A, int lda, - const KeysIteratorT d_keys, + KeysIteratorT d_keys, char* d_keys_char, int nrows, int ncols, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t stream) { - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; reduce_rows_by_key(d_A, lda, d_keys, diff --git a/cpp/include/raft/linalg/detail/ternary_op.cuh b/cpp/include/raft/linalg/detail/ternary_op.cuh index 46a5385d51..7874f20f56 100644 --- a/cpp/include/raft/linalg/detail/ternary_op.cuh +++ b/cpp/include/raft/linalg/detail/ternary_op.cuh @@ -22,9 +22,9 @@ namespace raft { namespace linalg { namespace detail { -template +template __global__ void ternaryOpKernel( - math_t* out, const math_t* in1, const math_t* in2, const math_t* in3, IdxType len, Lambda op) + out_t* out, const math_t* in1, const math_t* in2, const math_t* in3, IdxType len, Lambda op) { typedef raft::TxN_t VecType; VecType a, b, c; @@ -41,8 +41,8 @@ __global__ void ternaryOpKernel( 
a.store(out, idx); } -template -void ternaryOpImpl(math_t* out, +template +void ternaryOpImpl(out_t* out, const math_t* in1, const math_t* in2, const math_t* in3, @@ -51,7 +51,7 @@ void ternaryOpImpl(math_t* out, cudaStream_t stream) { const IdxType nblks = raft::ceildiv(veclen_ ? len / veclen_ : len, (IdxType)TPB); - ternaryOpKernel + ternaryOpKernel <<>>(out, in1, in2, in3, len, op); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -70,8 +70,8 @@ void ternaryOpImpl(math_t* out, * @param op the device-lambda * @param stream cuda stream where to launch work */ -template -void ternaryOp(math_t* out, +template +void ternaryOp(out_t* out, const math_t* in1, const math_t* in2, const math_t* in3, @@ -81,22 +81,22 @@ void ternaryOp(math_t* out, { size_t bytes = len * sizeof(math_t); if (16 / sizeof(math_t) && bytes % 16 == 0) { - ternaryOpImpl( + ternaryOpImpl( out, in1, in2, in3, len, op, stream); } else if (8 / sizeof(math_t) && bytes % 8 == 0) { - ternaryOpImpl( + ternaryOpImpl( out, in1, in2, in3, len, op, stream); } else if (4 / sizeof(math_t) && bytes % 4 == 0) { - ternaryOpImpl( + ternaryOpImpl( out, in1, in2, in3, len, op, stream); } else if (2 / sizeof(math_t) && bytes % 2 == 0) { - ternaryOpImpl( + ternaryOpImpl( out, in1, in2, in3, len, op, stream); } else if (1 / sizeof(math_t)) { - ternaryOpImpl( + ternaryOpImpl( out, in1, in2, in3, len, op, stream); } else { - ternaryOpImpl(out, in1, in2, in3, len, op, stream); + ternaryOpImpl(out, in1, in2, in3, len, op, stream); } } diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index 4f65544058..ef5551ea7e 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -18,8 +18,8 @@ #include "cublas_wrappers.hpp" +#include #include -#include #include #include #include diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 820c42f0ea..53b083045e 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -20,6 +20,9 @@ #include "detail/divide.cuh" +#include +#include + namespace raft { namespace linalg { @@ -27,7 +30,8 @@ using detail::divides_scalar; /** * @defgroup ScalarOps Scalar operations on the input buffer - * @tparam math_t data-type upon which the math operation will be performed + * @tparam OutT output data-type upon which the math operation will be performed + * @tparam InT input data-type upon which the math operation will be performed * @tparam IdxType Integer type used to for addressing * @param out the output buffer * @param in the input buffer @@ -36,13 +40,62 @@ using detail::divides_scalar; * @param stream cuda stream where to launch work * @{ */ -template -void divideScalar(math_t* out, const math_t* in, math_t scalar, IdxType len, cudaStream_t stream) +template +void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_t stream) { detail::divideScalar(out, in, scalar, len, stream); } /** @} */ +/** + * @defgroup divide Division Arithmetic + * @{ + */ + +/** + * @brief Elementwise division of input by host scalar + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[in] scalar raft::host_scalar_view + * @param[out] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void divide_scalar(const raft::handle_t& handle, + InType in, + OutType out, + 
raft::host_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + divideScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + divideScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** @} */ // end of group add + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh index f1f02dc13e..2ad222d42d 100644 --- a/cpp/include/raft/linalg/eig.cuh +++ b/cpp/include/raft/linalg/eig.cuh @@ -20,6 +20,8 @@ #include "detail/eig.cuh" +#include + namespace raft { namespace linalg { @@ -73,9 +75,9 @@ using detail::OVERWRITE_INPUT; template void eigSelDC(const raft::handle_t& handle, math_t* in, - int n_rows, - int n_cols, - int n_eig_vals, + std::size_t n_rows, + std::size_t n_cols, + std::size_t n_eig_vals, math_t* eig_vectors, math_t* eig_vals, EigVecMemUsage memUsage, @@ -102,8 +104,8 @@ void eigSelDC(const raft::handle_t& handle, template void eigJacobi(const raft::handle_t& handle, const math_t* in, - int n_rows, - int n_cols, + std::size_t n_rows, + std::size_t n_cols, math_t* eig_vectors, math_t* eig_vals, cudaStream_t stream, @@ -112,6 +114,109 @@ void eigJacobi(const raft::handle_t& handle, { detail::eigJacobi(handle, in, n_rows, n_cols, eig_vectors, eig_vals, stream, tol, sweeps); } + +/** + * @brief eig decomp with divide and conquer method for the column-major + * symmetric matrices + * @tparam ValueType the data-type of input and output + * @tparam IntegerType Integer used for addressing + * @param handle raft::handle_t + * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and + * vectors) + * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view + * @param[out] eig_vals: eigen values output of type raft::device_vector_view + */ +template +void eig_dc(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view eig_vectors, + raft::device_vector_view eig_vals) +{ + RAFT_EXPECTS(in.size() == eig_vectors.size(), "Size mismatch between Input and Eigen Vectors"); + RAFT_EXPECTS(eig_vals.size() == in.extent(1), "Size mismatch between Input and Eigen Values"); + + eigDC(handle, + in.data_handle(), + in.extent(0), + in.extent(1), + eig_vectors.data_handle(), + eig_vals.data_handle(), + handle.get_stream()); +} + +/** + * @brief eig decomp to select top-n eigen values with divide and conquer method + * for the column-major symmetric matrices + * @tparam ValueType the data-type of input and output + * @tparam IntegerType Integer used for addressing + * @param[in] handle raft::handle_t + * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and + * vectors) + * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view + * @param[out] eig_vals: eigen values output of type raft::device_vector_view + * @param[in] n_eig_vals: number of eigenvectors to be generated + * @param[in] memUsage: the memory selection for eig vector output + */ +template +void 
eig_dc_selective(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view eig_vectors, + raft::device_vector_view eig_vals, + std::size_t n_eig_vals, + EigVecMemUsage memUsage) +{ + RAFT_EXPECTS(eig_vectors.size() == n_eig_vals * in.extent(0), + "Size mismatch between Input and Eigen Vectors"); + RAFT_EXPECTS(eig_vals.size() == n_eig_vals, "Size mismatch between Input and Eigen Values"); + + raft::linalg::eigSelDC(handle, + const_cast(in.data_handle()), + in.extent(0), + in.extent(1), + n_eig_vals, + eig_vectors.data_handle(), + eig_vals.data_handle(), + memUsage, + handle.get_stream()); +} + +/** + * @brief overloaded function for eig decomp with Jacobi method for the + * column-major symmetric matrices (in parameter) + * @tparam ValueType the data-type of input and output + * @tparam IntegerType Integer used for addressing + * @param handle raft::handle_t + * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and + * vectors) + * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view + * @param[out] eig_vals: eigen values output of type raft::device_vector_view + * @param[in] tol: error tolerance for the jacobi method. Algorithm stops when the + Frobenius norm of the absolute error is below tol + * @param[in] sweeps: number of sweeps in the Jacobi algorithm. The more the better + * accuracy. + */ +template +void eig_jacobi(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view eig_vectors, + raft::device_vector_view eig_vals, + ValueType tol = 1.e-7, + int sweeps = 15) +{ + RAFT_EXPECTS(in.size() == eig_vectors.size(), "Size mismatch between Input and Eigen Vectors"); + RAFT_EXPECTS(eig_vals.size() == in.extent(1), "Size mismatch between Input and Eigen Values"); + + eigJacobi(handle, + in.data_handle(), + in.extent(0), + in.extent(1), + eig_vectors.data_handle(), + eig_vals.data_handle(), + handle.get_stream(), + tol, + sweeps); +} + /** @} */ // end of eig }; // end namespace linalg diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index 16a5bc48ea..f2354da6c6 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -20,6 +20,12 @@ #include "detail/gemm.hpp" +#include +#include +#include +#include +#include + namespace raft { namespace linalg { @@ -145,7 +151,7 @@ void gemm(const raft::handle_t& handle, * @param x input matrix of size M rows x K columns * @param y input matrix of size K rows x N columns * @param _M number of rows of X and Z - * @param _N number of rows of Y and columns of Z + * @param _N number of columns of Y and columns of Z * @param _K number of columns of X and rows of Y * @param isZColMajor Storage layout of Z. true = col major, false = row major * @param isXColMajor Storage layout of X. true = col major, false = row major @@ -170,9 +176,102 @@ void gemm(const raft::handle_t& handle, T beta = T(0.0)) { detail::gemm( - handle, z, x, y, _M, _N, _K, isZColMajor, isXColMajor, isYColMajor, stream, alpha, beta); + handle, z, x, y, _M, _N, _K, isZColMajor, isXColMajor, isYColMajor, stream, &alpha, &beta); +} + +/** + * @defgroup gemm Matrix-Matrix Multiplication + * @{ + */ + +/** + * @brief GEMM function designed for handling all possible + * combinations of operand layouts (raft::row_major or raft::col_major) + * with scalars alpha and beta on the host or device + * It computes the following equation: Z = alpha . X * Y + beta . 
Z + * If alpha is not provided, it is assumed to be 1.0 + * If beta is not provided, it is assumed to be 0.0 + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of index + * @tparam LayoutPolicyX layout of X + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] handle raft handle + * @param[in] x input raft::device_matrix_view of size M rows x K columns + * @param[in] y input raft::device_matrix_view of size K rows x N columns + * @param[out] z output raft::device_matrix_view of size M rows x N columns + * @param[in] alpha optional raft::host_scalar_view or raft::device_scalar_view, default 1.0 + * @param[in] beta optional raft::host_scalar_view or raft::device_scalar_view, default 0.0 + */ +template , + typename = std::enable_if_t>, + std::is_same>>>> +void gemm(const raft::handle_t& handle, + raft::device_matrix_view x, + raft::device_matrix_view y, + raft::device_matrix_view z, + std::optional alpha = std::nullopt, + std::optional beta = std::nullopt) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(x), "X is not contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(y), "Y is not contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(z), "Z is not contiguous"); + + RAFT_EXPECTS(x.extent(0) == z.extent(0), "Number of rows of X and Z should be equal"); + RAFT_EXPECTS(y.extent(1) == z.extent(1), "Number of columns of Y and Z should be equal"); + RAFT_EXPECTS(x.extent(1) == y.extent(0), "Number of columns of X and rows of Y should be equal"); + + constexpr auto is_x_col_major = + std::is_same_v; + constexpr auto is_y_col_major = + std::is_same_v; + constexpr auto is_z_col_major = + std::is_same_v; + + constexpr auto device_mode = + std::is_same_v>; + + ValueType alpha_value = 1; + ValueType beta_value = 0; + + auto alpha_device = raft::make_device_scalar(handle, alpha_value); + auto beta_device = raft::make_device_scalar(handle, beta_value); + + auto alpha_host = raft::make_host_scalar(alpha_value); + auto beta_host = raft::make_host_scalar(beta_value); + + if constexpr (device_mode) { + if (!alpha) { alpha = alpha_device.view(); } + if (!beta) { beta = beta_device.view(); } + } else { + if (!alpha) { alpha = alpha_host.view(); } + if (!beta) { beta = beta_host.view(); } + } + + detail::gemm(handle, + z.data_handle(), + x.data_handle(), + y.data_handle(), + x.extent(0), + y.extent(1), + x.extent(1), + is_z_col_major, + is_x_col_major, + is_y_col_major, + handle.get_stream(), + alpha.value().data_handle(), + beta.value().data_handle()); } +/** @} */ // end of gemm + } // end namespace linalg } // end namespace raft diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh index 26a6386148..8132a742f8 100644 --- a/cpp/include/raft/linalg/gemv.cuh +++ b/cpp/include/raft/linalg/gemv.cuh @@ -20,6 +20,12 @@ #include "detail/gemv.hpp" +#include +#include +#include +#include +#include + namespace raft { namespace linalg { @@ -206,6 +212,98 @@ void gemv(const raft::handle_t& handle, detail::gemv(handle, A, n_rows_a, n_cols_a, lda, x, y, trans_a, stream); } +/** + * @defgroup gemv Matrix-Vector Multiplication + * @{ + */ + +/** + * @brief GEMV function designed for raft::col_major layout for A + * It computes y = alpha * op(A) * x + beta * y, where length of y is number + * of rows in A while length of x is number of columns in A + * If layout for A is provided as raft::row_major, then a transpose of A + * is used in the computation, where length of y is number of columns in A + * while 
length of x is number of rows in A + * If alpha is not provided, it is assumed to be 1.0 + * If beta is not provided, it is assumed to be 0.0 + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of index + * @tparam LayoutPolicyX layout of X + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] handle raft handle + * @param[in] A input raft::device_matrix_view of size (M, N) + * @param[in] x input raft::device_matrix_view of size (N, 1) if A is raft::col_major, else (M, 1) + * @param[out] y output raft::device_matrix_view of size (M, 1) if A is raft::col_major, else (N, 1) + * @param[in] alpha optional raft::host_scalar_view or raft::device_scalar_view, default 1.0 + * @param[in] beta optional raft::host_scalar_view or raft::device_scalar_view, default 0.0 + */ +template , + typename = std::enable_if_t>, + std::is_same>>>> +void gemv(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view x, + raft::device_vector_view y, + std::optional alpha = std::nullopt, + std::optional beta = std::nullopt) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(A), "A is not contiguous"); + + constexpr auto is_A_col_major = + std::is_same_v; + + if (is_A_col_major) { + RAFT_EXPECTS(x.extent(0) == A.extent(1), + "Number of columns of A and length of x should be equal"); + RAFT_EXPECTS(y.extent(0) == A.extent(0), "Number of rows of A and length of y should be equal"); + } else { + RAFT_EXPECTS(x.extent(0) == A.extent(0), "Number of rows of A and length of x should be equal"); + RAFT_EXPECTS(y.extent(0) == A.extent(1), + "Number of columns of A and length of y should be equal"); + } + + constexpr auto device_mode = + std::is_same_v>; + + ValueType alpha_value = 1; + ValueType beta_value = 0; + + auto alpha_device = raft::make_device_scalar(handle, alpha_value); + auto beta_device = raft::make_device_scalar(handle, beta_value); + + auto alpha_host = raft::make_host_scalar(alpha_value); + auto beta_host = raft::make_host_scalar(beta_value); + + if constexpr (device_mode) { + if (!alpha) { alpha = alpha_device.view(); } + if (!beta) { beta = beta_device.view(); } + } else { + if (!alpha) { alpha = alpha_host.view(); } + if (!beta) { beta = beta_host.view(); } + } + + gemv(handle, + !is_A_col_major, + A.extent(0), + A.extent(1), + alpha.value().data_handle(), + A.data_handle(), + A.extent(0), + x.data_handle(), + 1, + beta.value().data_handle(), + y.data_handle(), + 1, + handle.get_stream()); +} +/** @} */ // end of gemv + }; // namespace linalg }; // namespace raft #endif \ No newline at end of file diff --git a/cpp/include/raft/linalg/linalg_types.hpp b/cpp/include/raft/linalg/linalg_types.hpp new file mode 100644 index 0000000000..e50d3a8e79 --- /dev/null +++ b/cpp/include/raft/linalg/linalg_types.hpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +namespace raft::linalg { + +/** + * @brief Enum for reduction/broadcast where an operation is to be performed along + * a matrix's rows or columns + * + */ +enum class Apply { ALONG_ROWS, ALONG_COLUMNS }; + +/** + * @brief Enum for reduction/broadcast where an operation is to be performed along + * a matrix's rows or columns + * + */ +enum class FillMode { UPPER, LOWER }; + +} // end namespace raft::linalg \ No newline at end of file diff --git a/cpp/include/raft/linalg/lstsq.cuh b/cpp/include/raft/linalg/lstsq.cuh index 1a4c5cf704..7654812886 100644 --- a/cpp/include/raft/linalg/lstsq.cuh +++ b/cpp/include/raft/linalg/lstsq.cuh @@ -115,6 +115,135 @@ void lstsqQR(const raft::handle_t& handle, detail::lstsqQR(handle, A, n_rows, n_cols, b, w, stream); } +/** + * @defgroup lstsq Least Squares Methods + * @{ + */ + +/** + * @brief Solves the linear ordinary least squares problem `Aw = b` + * Via SVD decomposition of `A = U S Vt`. + * + * @tparam ValueType the data-type of input/output + * @param[in] handle raft::handle_t + * @param[inout] A input raft::device_matrix_view + * Warning: the content of this matrix is modified. + * @param[inout] b input target raft::device_vector_view + * Warning: the content of this vector is modified. + * @param[out] w output coefficient raft::device_vector_view + */ +template +void lstsq_svd_qr(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + raft::device_vector_view w) +{ + RAFT_EXPECTS(A.extent(1) == w.size(), "Size mismatch between A and w"); + RAFT_EXPECTS(A.extent(0) == b.size(), "Size mismatch between A and b"); + + lstsqSvdQR(handle, + const_cast(A.data_handle()), + A.extent(0), + A.extent(1), + const_cast(b.data_handle()), + w.data_handle(), + handle.get_stream()); +} + +/** + * @brief Solves the linear ordinary least squares problem `Aw = b` + * Via SVD decomposition of `A = U S V^T` using Jacobi iterations. + * + * @tparam ValueType the data-type of input/output + * @param[in] handle raft::handle_t + * @param[inout] A input raft::device_matrix_view + * Warning: the content of this matrix is modified. + * @param[inout] b input target raft::device_vector_view + * Warning: the content of this vector is modified. + * @param[out] w output coefficient raft::device_vector_view + */ +template +void lstsq_svd_jacobi(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + raft::device_vector_view w) +{ + RAFT_EXPECTS(A.extent(1) == w.size(), "Size mismatch between A and w"); + RAFT_EXPECTS(A.extent(0) == b.size(), "Size mismatch between A and b"); + + lstsqSvdJacobi(handle, + const_cast(A.data_handle()), + A.extent(0), + A.extent(1), + const_cast(b.data_handle()), + w.data_handle(), + handle.get_stream()); +} + +/** + * @brief Solves the linear ordinary least squares problem `Aw = b` + * via eigenvalue decomposition of `A^T * A` (covariance matrix for dataset A). + * (`w = (A^T A)^-1 A^T b`) + * + * @tparam ValueType the data-type of input/output + * @param[in] handle raft::handle_t + * @param[inout] A input raft::device_matrix_view + * Warning: the content of this matrix is modified by the cuSOLVER routines. + * @param[inout] b input target raft::device_vector_view + * Warning: the content of this vector is modified by the cuSOLVER routines. 
+ * @param[out] w output coefficient raft::device_vector_view + */ +template +void lstsq_eig(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + raft::device_vector_view w) +{ + RAFT_EXPECTS(A.extent(1) == w.size(), "Size mismatch between A and w"); + RAFT_EXPECTS(A.extent(0) == b.size(), "Size mismatch between A and b"); + + lstsqEig(handle, + const_cast(A.data_handle()), + A.extent(0), + A.extent(1), + const_cast(b.data_handle()), + w.data_handle(), + handle.get_stream()); +} + +/** + * @brief Solves the linear ordinary least squares problem `Aw = b` + * via QR decomposition of `A = QR`. + * (triangular system of equations `Rw = Q^T b`) + * + * @tparam ValueType the data-type of input/output + * @param[in] handle raft::handle_t + * @param[inout] A input raft::device_matrix_view + * Warning: the content of this matrix is modified. + * @param[inout] b input target raft::device_vector_view + * Warning: the content of this vector is modified. + * @param[out] w output coefficient raft::device_vector_view + */ +template +void lstsq_qr(const raft::handle_t& handle, + raft::device_matrix_view A, + raft::device_vector_view b, + raft::device_vector_view w) +{ + RAFT_EXPECTS(A.extent(1) == w.size(), "Size mismatch between A and w"); + RAFT_EXPECTS(A.extent(0) == b.size(), "Size mismatch between A and b"); + + lstsqQR(handle, + const_cast(A.data_handle()), + A.extent(0), + A.extent(1), + const_cast(b.data_handle()), + w.data_handle(), + handle.get_stream()); +} + +/** @} */ // end of lstsq + }; // namespace linalg }; // namespace raft diff --git a/cpp/include/raft/linalg/map.cuh b/cpp/include/raft/linalg/map.cuh index 5df4d24b4f..ad35cc5880 100644 --- a/cpp/include/raft/linalg/map.cuh +++ b/cpp/include/raft/linalg/map.cuh @@ -20,6 +20,9 @@ #include "detail/map.cuh" +#include +#include + namespace raft { namespace linalg { @@ -37,17 +40,64 @@ namespace linalg { * @param in the input array * @param args additional input arrays */ +template +void map_k( + OutType* out, IdxType len, MapOp map, cudaStream_t stream, const InType* in, Args... args) +{ + detail::mapImpl( + out, len, map, stream, in, args...); +} +/** + * @defgroup map Mapping ops + * @{ + */ + +/** + * @brief CUDA version of map + * @tparam InType data-type for math operation of type raft::device_mdspan + * @tparam MapOp the device-lambda performing the actual operation + * @tparam TPB threads-per-block in the final kernel launched + * @tparam OutType data-type of result of type raft::device_mdspan + * @tparam Args additional parameters + * @param[in] handle raft::handle_t + * @param[in] in the input of type raft::device_mdspan + * @param[out] out the output of the map operation of type raft::device_mdspan + * @param[in] map the device-lambda + * @param[in] args additional input arrays + */ template -void map(OutType* out, size_t len, MapOp map, cudaStream_t stream, const InType* in, Args... args) + typename = raft::enable_if_input_device_mdspan, + typename = raft::enable_if_output_device_mdspan> +void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args... 
args) { - detail::mapImpl(out, len, map, stream, in, args...); + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output is not exhaustive"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input is not exhaustive"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Input and Output"); + + if (out.size() <= std::numeric_limits::max()) { + map_k( + out.data_handle(), out.size(), map, handle.get_stream(), in.data_handle(), args...); + } else { + map_k( + out.data_handle(), out.size(), map, handle.get_stream(), in.data_handle(), args...); + } } +/** @} */ // end of map + } // namespace linalg }; // namespace raft diff --git a/cpp/include/raft/linalg/map_reduce.cuh b/cpp/include/raft/linalg/map_reduce.cuh new file mode 100644 index 0000000000..180ed128a1 --- /dev/null +++ b/cpp/include/raft/linalg/map_reduce.cuh @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __MAP_REDUCE_H +#define __MAP_REDUCE_H + +#pragma once + +#include "detail/map_then_reduce.cuh" + +#include + +namespace raft::linalg { + +/** + * @defgroup map_reduce Map-Reduce ops + * @{ + */ + +/** + * @brief CUDA version of map and then generic reduction operation + * @tparam Type data-type upon which the math operation will be performed + * @tparam MapOp the device-lambda performing the actual map operation + * @tparam ReduceLambda the device-lambda performing the actual reduction + * @tparam TPB threads-per-block in the final kernel launched + * @tparam Args additional parameters + * @param out the output reduced value (assumed to be a device pointer) + * @param len number of elements in the input array + * @param neutral The neutral element of the reduction operation. For example: + * 0 for sum, 1 for multiply, +Inf for Min, -Inf for Max + * @param map the device-lambda + * @param op the reduction device lambda + * @param stream cuda-stream where to launch this kernel + * @param in the input array + * @param args additional input arrays + */ + +template +void mapReduce(OutType* out, + size_t len, + OutType neutral, + MapOp map, + ReduceLambda op, + cudaStream_t stream, + const InType* in, + Args... args) +{ + detail::mapThenReduceImpl( + out, len, neutral, map, op, stream, in, args...); +} + +/** + * @brief CUDA version of map and then generic reduction operation + * @tparam InValueType the data-type of the input + * @tparam MapOp the device-lambda performing the actual map operation + * @tparam ReduceLambda the device-lambda performing the actual reduction + * @tparam IndexType the index type + * @tparam OutValueType the data-type of the output + * @tparam ScalarIdxType index type of scalar + * @tparam Args additional parameters + * @param[in] handle raft::handle_t + * @param[in] in the input of type raft::device_vector_view + * @param[in] neutral The neutral element of the reduction operation. 
For example: + * 0 for sum, 1 for multiply, +Inf for Min, -Inf for Max + * @param[out] out the output reduced value assumed to be a raft::device_scalar_view + * @param[in] map the fused device-lambda + * @param[in] op the fused reduction device lambda + * @param[in] args additional input arrays + */ +template +void map_reduce(const raft::handle_t& handle, + raft::device_vector_view in, + raft::device_scalar_view out, + OutValueType neutral, + MapOp map, + ReduceLambda op, + Args... args) +{ + mapReduce( + out.data_handle(), + in.extent(0), + neutral, + map, + op, + handle.get_stream(), + in.data_handle(), + args...); +} + +/** @} */ // end of map_reduce + +} // end namespace raft::linalg + +#endif \ No newline at end of file diff --git a/cpp/include/raft/linalg/map_then_reduce.cuh b/cpp/include/raft/linalg/map_then_reduce.cuh index 36828cf154..a69ac6df36 100644 --- a/cpp/include/raft/linalg/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/map_then_reduce.cuh @@ -39,13 +39,14 @@ namespace linalg { template void mapThenSumReduce( - OutType* out, size_t len, MapOp map, cudaStream_t stream, const InType* in, Args... args) + OutType* out, IdxType len, MapOp map, cudaStream_t stream, const InType* in, Args... args) { - detail::mapThenReduceImpl( + detail::mapThenReduceImpl( out, len, (OutType)0, map, detail::sum_tag(), stream, in, args...); } @@ -66,25 +67,27 @@ void mapThenSumReduce( * @param in the input array * @param args additional input arrays */ - template -void mapThenReduce(OutType* out, - size_t len, - OutType neutral, - MapOp map, - ReduceLambda op, - cudaStream_t stream, - const InType* in, - Args... args) +[[deprecated("Use function `mapReduce` from `raft/linalg/map_reduce.cuh")]] void mapThenReduce( + OutType* out, + size_t len, + OutType neutral, + MapOp map, + ReduceLambda op, + cudaStream_t stream, + const InType* in, + Args... args) { - detail::mapThenReduceImpl( + detail::mapThenReduceImpl( out, len, neutral, map, op, stream, in, args...); } + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 56437313e3..1438a09bd3 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -19,6 +19,10 @@ #pragma once #include "detail/matrix_vector_op.cuh" +#include "linalg_types.hpp" + +#include +#include namespace raft { namespace linalg { @@ -99,6 +103,142 @@ void matrixVectorOp(Type* out, detail::matrixVectorOp(out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream); } +/** + * @defgroup matrix_vector_op Matrix Vector Operations + * @{ + */ + +/** + * @brief Operations for all the columns or rows with a given vector. + * Caution : Threads process multiple elements to speed up processing. These + * are loaded in a single read thanks to type promotion. Faster processing + * would thus only be enabled when adresses are optimally aligned for it. 
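+ *
+ * A minimal usage sketch of this single-vector overload (illustrative only; the element
+ * type, the extents and the addition lambda are placeholders):
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20;
+ *   auto matrix = raft::make_device_matrix<float, int, raft::row_major>(handle, n_rows, n_cols);
+ *   auto vec    = raft::make_device_vector<float, int>(handle, n_cols);
+ *   auto out    = raft::make_device_matrix<float, int, raft::row_major>(handle, n_rows, n_cols);
+ *   raft::linalg::matrix_vector_op(handle, matrix.view(), vec.view(), out.view(),
+ *                                  raft::linalg::Apply::ALONG_ROWS,
+ *                                  [] __device__(float mat_elem, float vec_elem) {
+ *                                    return mat_elem + vec_elem;
+ *                                  });
+ * @endcode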
+ * Note : the function will also check that the size of the window of accesses + * is a multiple of the number of elements processed by a thread in order to + * enable faster processing + * @tparam InValueType the data-type of the input matrices and vectors + * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) + * @tparam Lambda a device function which represents a binary operator + * @tparam OutElementType the data-type of the output raft::matrix_view + * @tparam IndexType Integer used for addressing + * @tparam TPB threads per block of the cuda kernel launched + * @param[in] handle raft::handle_t + * @param[in] matrix input raft::matrix_view + * @param[in] vec vector raft::vector_view + * @param[out] out output raft::matrix_view + * @param[in] apply whether the broadcast of vector needs to happen along + * the rows of the matrix or columns using enum class raft::linalg::Apply + * @param[in] op the mathematical operation + */ +template +void matrix_vector_op(const raft::handle_t& handle, + raft::device_matrix_view matrix, + raft::device_vector_view vec, + raft::device_matrix_view out, + Apply apply, + Lambda op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(matrix), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == matrix.size(), "Size mismatch between Output and Input"); + + auto constexpr rowMajor = std::is_same_v; + auto bcastAlongRows = apply == Apply::ALONG_ROWS; + + if (bcastAlongRows) { + RAFT_EXPECTS(out.extent(1) == static_cast(vec.size()), + "Size mismatch between matrix and vector"); + } else { + RAFT_EXPECTS(out.extent(0) == static_cast(vec.size()), + "Size mismatch between matrix and vector"); + } + + matrixVectorOp(out.data_handle(), + matrix.data_handle(), + vec.data_handle(), + out.extent(1), + out.extent(0), + rowMajor, + bcastAlongRows, + op, + handle.get_stream()); +} + +/** + * @brief Operations for all the columns or rows with the given vectors. + * Caution : Threads process multiple elements to speed up processing. These + * are loaded in a single read thanks to type promotion. Faster processing + * would thus only be enabled when adresses are optimally aligned for it. 
+ * Note : the function will also check that the size of the window of accesses + * is a multiple of the number of elements processed by a thread in order to + * enable faster processing + * @tparam InValueType the data-type of the input matrices and vectors + * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) + * @tparam Lambda a device function which represents a binary operator + * @tparam OutElementType the data-type of the output raft::matrix_view + * @tparam IndexType Integer used for addressing + * @tparam TPB threads per block of the cuda kernel launched + * @param handle raft::handle_t + * @param matrix input raft::matrix_view + * @param vec1 the first vector raft::vector_view + * @param vec2 the second vector raft::vector_view + * @param out output raft::matrix_view + * @param apply whether the broadcast of vector needs to happen along + * the rows of the matrix or columns using enum class raft::linalg::Apply + * @param op the mathematical operation + */ +template +void matrix_vector_op(const raft::handle_t& handle, + raft::device_matrix_view matrix, + raft::device_vector_view vec1, + raft::device_vector_view vec2, + raft::device_matrix_view out, + Apply apply, + Lambda op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(matrix), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == matrix.size(), "Size mismatch between Output and Input"); + + auto constexpr rowMajor = std::is_same_v; + auto bcastAlongRows = apply == Apply::ALONG_ROWS; + + if (bcastAlongRows) { + RAFT_EXPECTS(out.extent(1) == static_cast(vec1.size()), + "Size mismatch between matrix and vector"); + RAFT_EXPECTS(out.extent(1) == static_cast(vec2.size()), + "Size mismatch between matrix and vector"); + } else { + RAFT_EXPECTS(out.extent(0) == static_cast(vec1.size()), + "Size mismatch between matrix and vector"); + RAFT_EXPECTS(out.extent(0) == static_cast(vec2.size()), + "Size mismatch between matrix and vector"); + } + + matrixVectorOp(out.data_handle(), + matrix.data_handle(), + vec1.data_handle(), + vec2.data_handle(), + out.extent(1), + out.extent(0), + rowMajor, + bcastAlongRows, + op, + handle.get_stream()); +} + +/** @} */ // end of group matrix_vector_op + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index 1b3297f926..ddfe58dad7 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -34,13 +34,45 @@ namespace linalg { * @param weight weight to apply to every term in the mean squared error calculation * @param stream cuda-stream where to launch this kernel */ -template +template void meanSquaredError( math_t* out, const math_t* A, const math_t* B, size_t len, math_t weight, cudaStream_t stream) { detail::meanSquaredError(out, A, B, len, weight, stream); } +/** + * @defgroup mean_squared_error Mean Squared Error + * @{ + */ + +/** + * @brief CUDA version mean squared error function mean((A-B)**2) + * @tparam InValueType Input data-type + * @tparam IndexType Input/Output index type + * @tparam OutValueType Output data-type + * @tparam TPB threads-per-block + * @param[in] handle raft::handle_t + * @param[in] A input raft::device_vector_view + * @param[in] B input raft::device_vector_view + * @param[out] out the output mean squared error value of type raft::device_scalar_view + * @param[in] weight weight to apply to 
every term in the mean squared error calculation + */ +template +void mean_squared_error(const raft::handle_t& handle, + raft::device_vector_view A, + raft::device_vector_view B, + raft::device_scalar_view out, + OutValueType weight) +{ + RAFT_EXPECTS(A.size() == B.size(), "Size mismatch between inputs"); + + meanSquaredError( + out.data_handle(), A.data_handle(), B.data_handle(), A.extent(0), weight, stream); +} + +/** @} */ // end of group mean_squared_error + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh index f1161b23cb..119cf667d1 100644 --- a/cpp/include/raft/linalg/multiply.cuh +++ b/cpp/include/raft/linalg/multiply.cuh @@ -20,12 +20,17 @@ #include "detail/multiply.cuh" +#include +#include +#include + namespace raft { namespace linalg { /** * @defgroup ScalarOps Scalar operations on the input buffer - * @tparam math_t data-type upon which the math operation will be performed + * @tparam out_t data-type upon which the math operation will be performed + * @tparam in_t input data-type * @tparam IdxType Integer type used to for addressing * @param out the output buffer * @param in the input buffer @@ -34,13 +39,64 @@ namespace linalg { * @param stream cuda stream where to launch work * @{ */ -template -void multiplyScalar(math_t* out, const math_t* in, math_t scalar, IdxType len, cudaStream_t stream) +template +void multiplyScalar(out_t* out, const in_t* in, in_t scalar, IdxType len, cudaStream_t stream) { detail::multiplyScalar(out, in, scalar, len, stream); } /** @} */ +/** + * @defgroup multiply Multiplication Arithmetic + * @{ + */ + +/** + * @brief Element-wise multiplication of host scalar + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in the input buffer + * @param[out] out the output buffer + * @param[in] scalar the scalar used in the operations + * @{ + */ +template , + typename = raft::enable_if_output_device_mdspan> +void multiply_scalar( + const raft::handle_t& handle, + InType in, + OutType out, + raft::host_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + multiplyScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + multiplyScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** @} */ // end of group multiply + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index 87bd2a2b0a..389affef13 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -19,6 +19,10 @@ #pragma once #include "detail/norm.cuh" +#include "linalg_types.hpp" + +#include +#include namespace raft { namespace linalg { @@ -88,6 +92,61 @@ void colNorm(Type* dots, detail::colNormCaller(dots, data, D, N, type, rowMajor, stream, fin_op); } +/** + * @brief Compute norm of the input matrix and perform fin_op + * @tparam ElementType 
Input/Output data type + * @tparam LayoutPolicy the layout of input (raft::row_major or raft::col_major) + * @tparam IdxType Integer type used to for addressing + * @tparam Lambda device final lambda + * @param[in] handle raft::handle_t + * @param[in] in the input raft::device_matrix_view + * @param[out] out the output raft::device_vector_view + * @param[in] type the type of norm to be applied + * @param[in] apply Whether to apply the norm along rows (raft::linalg::Apply::ALONG_ROWS) + or along columns (raft::linalg::Apply::ALONG_COLUMNS) + * @param[in] fin_op the final lambda op + */ +template > +void norm(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view out, + NormType type, + Apply apply, + Lambda fin_op = raft::Nop()) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + + auto constexpr row_major = std::is_same_v; + auto along_rows = apply == Apply::ALONG_ROWS; + + if (along_rows) { + RAFT_EXPECTS(static_cast(out.size()) == in.extent(0), + "Output should be equal to number of rows in Input"); + rowNorm(out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + type, + row_major, + handle.get_stream(), + fin_op); + } else { + RAFT_EXPECTS(static_cast(out.size()) == in.extent(1), + "Output should be equal to number of columns in Input"); + colNorm(out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + type, + row_major, + handle.get_stream(), + fin_op); + } +} + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index 69f3e4d22b..acd226b71d 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -18,17 +18,19 @@ #pragma once +#include #include #include #include +#include namespace raft { namespace linalg { /** * @defgroup ScalarOps Scalar operations on the input buffer - * @tparam math_t data-type upon which the math operation will be performed - * @tparam IdxType Integer type used to for addressing + * @tparam in_t Input data-type + * @tparam out_t Output data-type * @param out the output buffer * @param in the input buffer * @param scalar the scalar used in the operations @@ -36,17 +38,18 @@ namespace linalg { * @param stream cuda stream where to launch work * @{ */ -template -void powerScalar(math_t* out, const math_t* in, math_t scalar, IdxType len, cudaStream_t stream) +template +void powerScalar(out_t* out, const in_t* in, const in_t scalar, IdxType len, cudaStream_t stream) { raft::linalg::unaryOp( - out, in, len, [scalar] __device__(math_t in) { return raft::myPow(in, scalar); }, stream); + out, in, len, [scalar] __device__(in_t in) { return raft::myPow(in, scalar); }, stream); } /** @} */ /** * @defgroup BinaryOps Element-wise binary operations on the input buffers - * @tparam math_t data-type upon which the math operation will be performed + * @tparam in_t Input data-type + * @tparam out_t Output data-type * @tparam IdxType Integer type used to for addressing * @param out the output buffer * @param in1 the first input buffer @@ -55,14 +58,103 @@ void powerScalar(math_t* out, const math_t* in, math_t scalar, IdxType len, cuda * @param stream cuda stream where to launch work * @{ */ -template -void power(math_t* out, const math_t* in1, const math_t* in2, IdxType len, cudaStream_t stream) +template +void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream_t stream) { raft::linalg::binaryOp( - out, in1, in2, len, [] __device__(math_t a, math_t 
b) { return raft::myPow(a, b); }, stream); + out, in1, in2, len, [] __device__(in_t a, in_t b) { return raft::myPow(a, b); }, stream); } /** @} */ +/** + * @defgroup power Power Arithmetic + * @{ + */ + +/** + * @brief Elementwise power operation on the input buffers + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @param[in] handle raft::handle_t + * @param[in] in1 First Input + * @param[in] in2 Second Input + * @param[out] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void power(const raft::handle_t& handle, InType in1, InType in2, OutType out) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous"); + RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(), + "Size mismatch between Output and Inputs"); + + if (out.size() <= std::numeric_limits::max()) { + power(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + power(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** + * @brief Elementwise power of host scalar to input + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[out] out Output + * @param[in] scalar raft::host_scalar_view + */ +template , + typename = raft::enable_if_output_device_mdspan> +void power_scalar( + const raft::handle_t& handle, + InType in, + OutType out, + const raft::host_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + powerScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + powerScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** @} */ // end of group add + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/qr.cuh b/cpp/include/raft/linalg/qr.cuh index fe6a5263ca..7e6e14e680 100644 --- a/cpp/include/raft/linalg/qr.cuh +++ b/cpp/include/raft/linalg/qr.cuh @@ -36,7 +36,6 @@ namespace linalg { * @param n_rows: number rows of input matrix * @param n_cols: number columns of input matrix * @param stream cuda stream - * @{ */ template void qrGetQ(const raft::handle_t& handle, @@ -70,6 +69,47 @@ void qrGetQR(const raft::handle_t& handle, { detail::qrGetQR(handle, M, Q, R, n_rows, n_cols, stream); } + +/** + * @brief Compute the QR decomposition of matrix M and return only the Q matrix. 
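+ *
+ * A minimal usage sketch (illustrative only; the element type, the extents and the
+ * const-qualification of the input view are assumptions):
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20;
+ *   auto M = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
+ *   auto Q = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
+ *   // ... fill M on the handle's stream ...
+ *   raft::linalg::qr_get_q(
+ *     handle,
+ *     raft::make_device_matrix_view<const float, int, raft::col_major>(M.data_handle(), n_rows, n_cols),
+ *     Q.view());
+ * @endcode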
+ * @param[in] handle raft::handle_t + * @param[in] M Input raft::device_matrix_view + * @param[out] Q Output raft::device_matrix_view + */ +template +void qr_get_q(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_matrix_view Q) +{ + RAFT_EXPECTS(Q.size() == M.size(), "Size mismatch between Output and Input"); + + qrGetQ(handle, M.data_handle(), Q.data_handle(), M.extent(0), M.extent(1), handle.get_stream()); +} + +/** + * @brief Compute the QR decomposition of matrix M and return both the Q and R matrices. + * @param[in] handle raft::handle_t + * @param[in] M Input raft::device_matrix_view + * @param[in] Q Output raft::device_matrix_view + * @param[out] R Output raft::device_matrix_view + */ +template +void qr_get_qr(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_matrix_view Q, + raft::device_matrix_view R) +{ + RAFT_EXPECTS(Q.size() == M.size(), "Size mismatch between Output and Input"); + + qrGetQR(handle, + M.data_handle(), + Q.data_handle(), + R.data_handle(), + M.extent(0), + M.extent(1), + handle.get_stream()); +} + /** @} */ }; // namespace linalg diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 7640da8c2d..9c349ccb4f 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -19,6 +19,10 @@ #pragma once #include "detail/reduce.cuh" +#include "linalg_types.hpp" + +#include +#include namespace raft { namespace linalg { @@ -60,8 +64,8 @@ template > void reduce(OutType* dots, const InType* data, - int D, - int N, + IdxType D, + IdxType N, OutType init, bool rowMajor, bool alongRows, @@ -75,6 +79,87 @@ void reduce(OutType* dots, dots, data, D, N, init, rowMajor, alongRows, stream, inplace, main_op, reduce_op, final_op); } +/** + * @defgroup reduction Reduction Along Requested Dimension + * @{ + */ + +/** + * @brief Compute reduction of the input matrix along the requested dimension + * This API computes a reduction of a matrix whose underlying storage + * is either row-major or column-major, while allowing the choose the + * dimension for reduction. Depending upon the dimension chosen for + * reduction, the memory accesses may be coalesced or strided. + * + * @tparam InElementType the input data-type of underlying raft::matrix_view + * @tparam LayoutPolicy The layout of Input/Output (row or col major) + * @tparam OutElementType the output data-type of underlying raft::matrix_view and reduction + * @tparam IndexType Integer type used to for addressing + * @tparam MainLambda Unary lambda applied while acculumation (eg: L1 or L2 norm) + * It must be a 'callable' supporting the following input and output: + *
OutType (*MainLambda)(InType, IdxType);
+ * @tparam ReduceLambda Binary lambda applied for reduction (eg: addition(+) for L2 norm) + * It must be a 'callable' supporting the following input and output: + *
OutType (*ReduceLambda)(OutType, OutType);
+ * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) + * It must be a 'callable' supporting the following input and output: + *
OutType (*FinalLambda)(OutType);
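+ *
+ * A minimal usage sketch (illustrative only; the element type and extents are placeholders).
+ * With Apply::ALONG_ROWS the output length must equal the number of columns of the input,
+ * per the size checks in this wrapper:
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20;
+ *   auto data = raft::make_device_matrix<float, int, raft::row_major>(handle, n_rows, n_cols);
+ *   auto dots = raft::make_device_vector<float, int>(handle, n_cols);
+ *   // sum-reduction with the default main/reduce/final ops
+ *   raft::linalg::reduce(handle, data.view(), dots.view(), 0.0f,
+ *                        raft::linalg::Apply::ALONG_ROWS);
+ * @endcode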
+ * @param[in] handle raft::handle_t + * @param[in] data Input of type raft::device_matrix_view + * @param[out] dots Output of type raft::device_matrix_view + * @param[in] init initial value to use for the reduction + * @param[in] apply whether to reduce along rows or along columns (using raft::linalg::Apply) + * @param[in] main_op fused elementwise operation to apply before reduction + * @param[in] reduce_op fused binary reduction operation + * @param[in] final_op fused elementwise operation to apply before storing results + * @param[in] inplace reduction result added inplace or overwrites old values? + */ +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void reduce(const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view dots, + OutElementType init, + Apply apply, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(data), "Input must be contiguous"); + + auto constexpr row_major = std::is_same_v; + bool along_rows = apply == Apply::ALONG_ROWS; + + if (along_rows) { + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), + "Output should be equal to number of columns in Input"); + } else { + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), + "Output should be equal to number of rows in Input"); + } + + reduce(dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + row_major, + along_rows, + handle.get_stream(), + inplace, + main_op, + reduce_op, + final_op); +} + +/** @} */ // end of group reduction + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh index 2336639258..a7917f21f8 100644 --- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh @@ -18,7 +18,10 @@ #pragma once -#include +#include "detail/reduce_cols_by_key.cuh" + +#include +#include namespace raft { namespace linalg { @@ -52,6 +55,52 @@ void reduce_cols_by_key(const T* data, { detail::reduce_cols_by_key(data, keys, out, nrows, ncols, nkeys, stream); } + +/** + * @defgroup reduce_cols_by_key Reduce Across Columns by Key + * @{ + */ + +/** + * @brief Computes the sum-reduction of matrix columns for each given key + * TODO: Support generic reduction lambdas https://github.com/rapidsai/raft/issues/860 + * @tparam ElementType the input data type (as well as the output reduced matrix) + * @tparam KeyType data type of the keys + * @tparam IndexType indexing arithmetic type + * @param[in] handle raft::handle_t + * @param[in] data the input data (dim = nrows x ncols). This is assumed to be in + * row-major layout of type raft::device_matrix_view + * @param[in] keys keys raft::device_vector_view (len = ncols). It is assumed that each key in this + * array is between [0, nkeys). In case this is not true, the caller is expected + * to have called make_monotonic primitive to prepare such a contiguous and + * monotonically increasing keys array. + * @param[out] out the output reduced raft::device_matrix_view along columns (dim = nrows x nkeys). 
+ * This will be assumed to be in row-major layout + * @param[in] nkeys number of unique keys in the keys array + */ +template +void reduce_cols_by_key( + const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view keys, + raft::device_matrix_view out, + IndexType nkeys) +{ + RAFT_EXPECTS(out.extent(0) == data.extent(0) && out.extent(1) == nkeys, + "Output is not of size nrows * nkeys"); + RAFT_EXPECTS(keys.extent(0) == data.extent(1), "Keys is not of size ncols"); + + reduce_cols_by_key(data.data_handle(), + keys.data_handle(), + out.data_handle(), + data.extent(0), + data.extent(1), + nkeys, + handle.get_stream()); +} + +/** @} */ // end of group reduce_cols_by_key + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index ca7a956986..39c54e8b0c 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -18,7 +18,10 @@ #pragma once -#include +#include "detail/reduce_rows_by_key.cuh" + +#include +#include namespace raft { namespace linalg { @@ -53,7 +56,7 @@ void convert_array(IteratorT1 dst, IteratorT2 src, int n, cudaStream_t st) * @param[in] stream CUDA stream */ template -void reduce_rows_by_key(const DataIteratorT d_A, +void reduce_rows_by_key(const DataIteratorT* d_A, int lda, const KeysIteratorT d_keys, const WeightT* d_weights, @@ -61,7 +64,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, int nrows, int ncols, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t stream) { detail::reduce_rows_by_key( @@ -85,17 +88,17 @@ void reduce_rows_by_key(const DataIteratorT d_A, * @param[in] stream CUDA stream */ template -void reduce_rows_by_key(const DataIteratorT d_A, +void reduce_rows_by_key(const DataIteratorT* d_A, int lda, - const KeysIteratorT d_keys, + KeysIteratorT d_keys, char* d_keys_char, int nrows, int ncols, int nkeys, - DataIteratorT d_sums, + DataIteratorT* d_sums, cudaStream_t stream) { - typedef typename std::iterator_traits::value_type DataType; + typedef typename std::iterator_traits::value_type DataType; reduce_rows_by_key(d_A, lda, d_keys, @@ -108,6 +111,69 @@ void reduce_rows_by_key(const DataIteratorT d_A, stream); } +/** + * @defgroup reduce_rows_by_key Reduce Across Rows by Key + * @{ + */ + +/** + * @brief Computes the weighted sum-reduction of matrix rows for each given key + * TODO: Support generic reduction lambdas https://github.com/rapidsai/raft/issues/860 + * @tparam ElementType data-type of input and output + * @tparam KeyType data-type of keys + * @tparam WeightType data-type of weights + * @tparam IndexType index type + * @param[in] handle raft::handle_t + * @param[in] d_A Input raft::device_mdspan (ncols * nrows) + * @param[in] d_keys Keys for each row raft::device_vector_view (1 x nrows) + * @param[out] d_sums Row sums by key raft::device_matrix_view (ncols x d_keys) + * @param[in] n_unique_keys Number of unique keys in d_keys + * @param[in] d_weights Weights for each observation in d_A raft::device_vector_view optional (1 + * x nrows) + * @param[out] d_keys_char Scratch memory for conversion of keys to char, raft::device_vector_view + */ +template +void reduce_rows_by_key( + const raft::handle_t& handle, + raft::device_matrix_view d_A, + raft::device_vector_view d_keys, + raft::device_matrix_view d_sums, + IndexType n_unique_keys, + raft::device_vector_view d_keys_char, + std::optional> d_weights = std::nullopt) +{ + 
RAFT_EXPECTS(d_A.extent(0) == d_A.extent(0) && d_sums.extent(1) == n_unique_keys, + "Output is not of size ncols * n_unique_keys"); + RAFT_EXPECTS(d_keys.extent(0) == d_A.extent(1), "Keys is not of size nrows"); + + if (d_weights) { + RAFT_EXPECTS(d_weights.value().extent(0) == d_A.extent(1), "Weights is not of size nrows"); + + reduce_rows_by_key(d_A.data_handle(), + d_A.extent(0), + d_keys.data_handle(), + d_weights.value().data_handle(), + d_keys_char.data_handle(), + d_A.extent(1), + d_A.extent(0), + n_unique_keys, + d_sums.data_handle(), + handle.get_stream()); + } else { + reduce_rows_by_key(d_A.data_handle(), + d_A.extent(0), + d_keys.data_handle(), + d_keys_char.data_handle(), + d_A.extent(1), + d_A.extent(0), + n_unique_keys, + d_sums.data_handle(), + handle.get_stream()); + } +} + +/** @} */ // end of group reduce_rows_by_key + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index f5eaba7526..e465ee6fa2 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -18,7 +18,9 @@ #pragma once -#include +#include "detail/rsvd.cuh" + +#include namespace raft { namespace linalg { @@ -137,6 +139,653 @@ void rsvdPerc(const raft::handle_t& handle, stream); } +/** + * @defgroup rsvd Randomized Singular Value Decomposition + * @{ + */ + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using QR decomposition, by specifying no. of PCs and + * upsamples directly + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] p no. of upsamples + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_fixed_rank( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdFixedRank(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + S_vec.extent(0), + p, + false, + U.has_value(), + V.has_value(), + false, + static_cast(0), + 0, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_fixed_rank` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_fixed_rank`. 
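+ *
+ * A minimal usage sketch of the primary `rsvd_fixed_rank` above (illustrative only; the
+ * element type, extents, rank and upsampling value are placeholders):
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20, k = 5, p = 10;
+ *   auto M     = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
+ *   auto S_vec = raft::make_device_vector<float, int>(handle, k);
+ *   auto U     = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, k);
+ *   auto V     = raft::make_device_matrix<float, int, raft::col_major>(handle, k, n_cols);
+ *   auto M_const = raft::make_device_matrix_view<const float, int, raft::col_major>(
+ *     M.data_handle(), n_rows, n_cols);
+ *   raft::linalg::rsvd_fixed_rank(handle, M_const, S_vec.view(), p,
+ *                                 std::make_optional(U.view()),
+ *                                 std::make_optional(V.view()));
+ * @endcode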
+ */ +template +void rsvd_fixed_rank(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_fixed_rank(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using symmetric Eigen decomposition, by specifying no. of PCs and + * upsamples directly. The rectangular input matrix is made square and symmetric using B @ B^T + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] p no. of upsamples + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_fixed_rank_symmetric( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdFixedRank(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + S_vec.extent(0), + p, + true, + U.has_value(), + V.has_value(), + false, + static_cast(0), + 0, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_fixed_rank_symmetric` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_fixed_rank_symmetric`. + */ +template +void rsvd_fixed_rank_symmetric( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_fixed_rank_symmetric(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using Jacobi method, by specifying no. of PCs and + * upsamples directly + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] p no. 
of upsamples + * @param[in] tol tolerance for Jacobi-based solvers + * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_fixed_rank_jacobi( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdFixedRank(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + S_vec.extent(0), + p, + false, + U.has_value(), + V.has_value(), + true, + tol, + max_sweeps, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_fixed_rank_jacobi` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_fixed_rank_jacobi`. + */ +template +void rsvd_fixed_rank_jacobi(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_fixed_rank_sjacobi(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using Jacobi method, by specifying no. of PCs and + * upsamples directly. The rectangular input matrix is made square and symmetric using B @ B^T + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] p no. 
of upsamples + * @param[in] tol tolerance for Jacobi-based solvers + * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_fixed_rank_symmetric_jacobi( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdFixedRank(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + S_vec.extent(0), + p, + true, + U.has_value(), + V.has_value(), + true, + tol, + max_sweeps, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_fixed_rank_symmetric_jacobi` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_fixed_rank_symmetric_jacobi`. + */ +template +void rsvd_fixed_rank_symmetric_jacobi( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + IndexType p, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_fixed_rank_symmetric_jacobi(handle, M, S_vec, p, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using QR decomposition, by specifying the PC and upsampling + * ratio + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] PC_perc percentage of singular values to be computed + * @param[in] UpS_perc upsampling percentage + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_perc( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdPerc(handle, + const_cast(M.data_handle()), + 
M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + PC_perc, + UpS_perc, + false, + U.has_value(), + V.has_value(), + false, + static_cast(0), + 0, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_perc` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_perc`. + */ +template +void rsvd_perc(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_perc(handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using symmetric Eigen decomposition, by specifying the PC and upsampling + * ratio. The rectangular input matrix is made square and symmetric using B @ B^T + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] PC_perc percentage of singular values to be computed + * @param[in] UpS_perc upsampling percentage + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_perc_symmetric( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdPerc(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + PC_perc, + UpS_perc, + true, + U.has_value(), + V.has_value(), + false, + static_cast(0), + 0, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_perc_symmetric` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_perc_symmetric`. 
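+ *
+ * A minimal usage sketch of the primary `rsvd_perc_symmetric` above (illustrative only;
+ * the element type, extents and percentages are placeholders):
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20, k = 5;
+ *   auto M     = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
+ *   auto S_vec = raft::make_device_vector<float, int>(handle, k);
+ *   auto U     = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, k);
+ *   auto V     = raft::make_device_matrix<float, int, raft::col_major>(handle, k, n_cols);
+ *   auto M_const = raft::make_device_matrix_view<const float, int, raft::col_major>(
+ *     M.data_handle(), n_rows, n_cols);
+ *   // request 25% of the singular values with 5% upsampling
+ *   raft::linalg::rsvd_perc_symmetric(handle, M_const, S_vec.view(), 0.25f, 0.05f,
+ *                                     std::make_optional(U.view()),
+ *                                     std::make_optional(V.view()));
+ * @endcode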
+ */ +template +void rsvd_perc_symmetric(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_perc_symmetric(handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using Jacobi method, by specifying the PC and upsampling + * ratio + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] PC_perc percentage of singular values to be computed + * @param[in] UpS_perc upsampling percentage + * @param[in] tol tolerance for Jacobi-based solvers + * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_perc_jacobi( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdPerc(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + PC_perc, + UpS_perc, + false, + U.has_value(), + V.has_value(), + true, + tol, + max_sweeps, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_perc_jacobi` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_perc_jacobi`. + */ +template +void rsvd_perc_jacobi(const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_perc_jacobi(handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); +} + +/** + * @brief randomized singular value decomposition (RSVD) on a column major + * rectangular matrix using Jacobi method, by specifying the PC and upsampling + * ratio. 
The rectangular input matrix is made square and symmetric using B @ B^T + * @param[in] handle raft::handle_t + * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S_vec singular values raft::device_vector_view of shape (K) + * @param[in] PC_perc percentage of singular values to be computed + * @param[in] UpS_perc upsampling percentage + * @param[in] tol tolerance for Jacobi-based solvers + * @param[in] max_sweeps maximum number of sweeps for Jacobi-based solvers + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major + * @param[out] V optional right singular values of raft::device_matrix_view with layout + * raft::col_major + */ +template +void rsvd_perc_symmetric_jacobi( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + std::optional> U = std::nullopt, + std::optional> V = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(M.extent(0) == U.value().extent(0), "Number of rows in M should be equal to U"); + RAFT_EXPECTS(S_vec.extent(0) == U.value().extent(1), + "Number of columns in U should be equal to length of S"); + } + if (V) { + RAFT_EXPECTS(M.extent(1) == V.value().extent(1), "Number of columns in M should be equal to V"); + RAFT_EXPECTS(S_vec.extent(0) == V.value().extent(0), + "Number of rows in V should be equal to length of S"); + } + + rsvdPerc(handle, + const_cast(M.data_handle()), + M.extent(0), + M.extent(1), + S_vec.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + PC_perc, + UpS_perc, + true, + U.has_value(), + V.has_value(), + true, + tol, + max_sweeps, + handle.get_stream()); +} + +/** + * @brief Overload of `rsvd_perc_symmetric_jacobi` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `rsvd_perc_symmetric_jacobi`. 
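+ *
+ * A minimal usage sketch of the primary `rsvd_perc_symmetric_jacobi` above (illustrative
+ * only; the element type, extents, percentages, tolerance and sweep count are placeholders):
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20, k = 5;
+ *   auto M     = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, n_cols);
+ *   auto S_vec = raft::make_device_vector<float, int>(handle, k);
+ *   auto U     = raft::make_device_matrix<float, int, raft::col_major>(handle, n_rows, k);
+ *   auto V     = raft::make_device_matrix<float, int, raft::col_major>(handle, k, n_cols);
+ *   auto M_const = raft::make_device_matrix_view<const float, int, raft::col_major>(
+ *     M.data_handle(), n_rows, n_cols);
+ *   raft::linalg::rsvd_perc_symmetric_jacobi(handle, M_const, S_vec.view(), 0.25f, 0.05f,
+ *                                            1e-7f, 15,
+ *                                            std::make_optional(U.view()),
+ *                                            std::make_optional(V.view()));
+ * @endcode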
+ */ +template +void rsvd_perc_symmetric_jacobi( + const raft::handle_t& handle, + raft::device_matrix_view M, + raft::device_vector_view S_vec, + ValueType PC_perc, + ValueType UpS_perc, + ValueType tol, + int max_sweeps, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + rsvd_perc_symmetric_jacobi( + handle, M, S_vec, PC_perc, UpS_perc, tol, max_sweeps, U_optional, V_optional); +} + +/** @} */ // end of group rsvd + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index c81e38eace..2951285c3a 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -18,6 +18,7 @@ #pragma once +#include #include #include @@ -34,14 +35,55 @@ namespace linalg { * @param stream cuda stream where to launch work * @{ */ -template -void sqrt(math_t* out, const math_t* in, IdxType len, cudaStream_t stream) +template +void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) { raft::linalg::unaryOp( - out, in, len, [] __device__(math_t in) { return raft::mySqrt(in); }, stream); + out, in, len, [] __device__(in_t in) { return raft::mySqrt(in); }, stream); } /** @} */ +/** + * @defgroup sqrt Sqrt Arithmetic + * @{ + */ + +/** + * @brief Elementwise sqrt operation + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[out] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void sqrt(const raft::handle_t& handle, InType in, OutType out) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input 1 must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Inputs"); + + if (out.size() <= std::numeric_limits::max()) { + sqrt(out.data_handle(), + in.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + sqrt(out.data_handle(), + in.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** @} */ // end of group add + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 941e64dcb1..6927269821 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -21,6 +21,9 @@ #include "detail/strided_reduction.cuh" +#include +#include + namespace raft { namespace linalg { @@ -71,6 +74,90 @@ void stridedReduction(OutType* dots, detail::stridedReduction(dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } +/** + * @defgroup strided_reduction Strided Memory Access Reductions + * For reducing along rows for row-major and along columns for column-major + * @{ + */ + +/** + * @brief Compute reduction of the input matrix along the strided dimension + * This API is to be used when the desired reduction is NOT along the dimension + * of the memory layout. For example, a row-major matrix will be reduced + * along the rows whereas a column-major matrix will be reduced along + * the columns. 
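+ *
+ * A minimal usage sketch (illustrative only; the element type and extents are placeholders).
+ * For a row-major input the reduced output below holds one value per column:
+ * @code{.cpp}
+ *   raft::handle_t handle;
+ *   int n_rows = 100, n_cols = 20;
+ *   auto data = raft::make_device_matrix<float, int, raft::row_major>(handle, n_rows, n_cols);
+ *   auto dots = raft::make_device_vector<float, int>(handle, n_cols);
+ *   // sum-reduce each column (default main/reduce/final ops)
+ *   raft::linalg::strided_reduction(handle, data.view(), dots.view(), 0.0f);
+ * @endcode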
+ * + * @tparam InValueType the input data-type of underlying raft::matrix_view + * @tparam LayoutPolicy The layout of Input/Output (row or col major) + * @tparam OutValueType the output data-type of underlying raft::matrix_view and reduction + * @tparam IndexType Integer type used for addressing + * @tparam MainLambda Unary lambda applied during accumulation (eg: L1 or L2 norm) + * It must be a 'callable' supporting the following input and output: + *
OutValueType (*MainLambda)(InValueType, IndexType);
+ * @tparam ReduceLambda Binary lambda applied for reduction (eg: addition(+) for L2 norm) + * It must be a 'callable' supporting the following input and output: + *
OutValueType (*ReduceLambda)(OutValueType, OutValueType);
+ * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) + * It must be a 'callable' supporting the following input and output: + *
OutValueType (*FinalLambda)(OutValueType);
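 * For example, an L2 norm along the strided dimension can be sketched (illustratively) with the
 * following functors, passed as the main_op, reduce_op and final_op parameters below:
 *   main_op   = [] __device__(InValueType v, IndexType) { return v * v; };
 *   reduce_op = [] __device__(OutValueType a, OutValueType b) { return a + b; };
 *   final_op  = [] __device__(OutValueType v) { return raft::mySqrt(v); };
 *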
+ * @param[in] handle raft::handle_t + * @param[in] data Input of type raft::device_matrix_view + * @param[out] dots Output of type raft::device_matrix_view + * @param[in] init initial value to use for the reduction + * @param[in] main_op fused elementwise operation to apply before reduction + * @param[in] reduce_op fused binary reduction operation + * @param[in] final_op fused elementwise operation to apply before storing results + * @param[in] inplace reduction result added inplace or overwrites old values? + */ +template , + typename ReduceLambda = raft::Sum, + typename FinalLambda = raft::Nop> +void strided_reduction(const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view dots, + OutValueType init, + bool inplace = false, + MainLambda main_op = raft::Nop(), + ReduceLambda reduce_op = raft::Sum(), + FinalLambda final_op = raft::Nop()) +{ + if constexpr (std::is_same_v) { + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(1), + "Output should be equal to number of columns in Input"); + + stridedReduction(dots.data_handle(), + data.data_handle(), + data.extent(1), + data.extent(0), + init, + handle.get_stream(), + inplace, + main_op, + reduce_op, + final_op); + } else if constexpr (std::is_same_v) { + RAFT_EXPECTS(static_cast(dots.size()) == data.extent(0), + "Output should be equal to number of rows in Input"); + + stridedReduction(dots.data_handle(), + data.data_handle(), + data.extent(0), + data.extent(1), + init, + handle.get_stream(), + inplace, + main_op, + reduce_op, + final_op); + } +} + +/** @} */ // end of group strided_reduction + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh index 9ca36ddddf..4f81822a13 100644 --- a/cpp/include/raft/linalg/subtract.cuh +++ b/cpp/include/raft/linalg/subtract.cuh @@ -21,6 +21,10 @@ #include "detail/subtract.cuh" +#include +#include +#include + namespace raft { namespace linalg { @@ -84,6 +88,140 @@ void subtractDevScalar(math_t* outDev, detail::subtractDevScalar(outDev, inDev, singleScalarDev, len, stream); } +/** + * @defgroup sub Subtraction Arithmetic + * @{ + */ + +/** + * @brief Elementwise subtraction operation on the input buffers + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @param handle raft::handle_t + * @param[in] in1 First Input + * @param[in] in2 Second Input + * @param[out] out Output + */ +template , + typename = raft::enable_if_output_device_mdspan> +void subtract(const raft::handle_t& handle, InType in1, InType in2, OutType out) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous"); + RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size(), + "Size mismatch between Output and Inputs"); + + if (out.size() <= std::numeric_limits::max()) { + subtract(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + subtract(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** + * @brief Elementwise subtraction of device scalar to input + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output 
Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[out] out Output + * @param[in] scalar raft::device_scalar_view + */ +template , + typename = raft::enable_if_output_device_mdspan> +void subtract_scalar( + const raft::handle_t& handle, + InType in, + OutType out, + raft::device_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + subtractDevScalar( + out.data_handle(), + in.data_handle(), + scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + subtractDevScalar( + out.data_handle(), + in.data_handle(), + scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** + * @brief Elementwise subtraction of host scalar to input + * @tparam InType Input Type raft::device_mdspan + * @tparam OutType Output Type raft::device_mdspan + * @tparam ScalarIdxType Index Type of scalar + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[out] out Output + * @param[in] scalar raft::host_scalar_view + */ +template , + typename = raft::enable_if_output_device_mdspan> +void subtract_scalar( + const raft::handle_t& handle, + InType in, + OutType out, + raft::host_scalar_view scalar) +{ + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + if (out.size() <= std::numeric_limits::max()) { + subtractScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } else { + subtractScalar(out.data_handle(), + in.data_handle(), + *scalar.data_handle(), + static_cast(out.size()), + handle.get_stream()); + } +} + +/** @} */ // end of group subtract + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index b48def90a3..0026ec1f7d 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -20,6 +20,8 @@ #include "detail/svd.cuh" +#include + namespace raft { namespace linalg { @@ -38,9 +40,6 @@ namespace linalg { * @param gen_right_vec: generate right eig vector. Not activated. * @param stream cuda stream */ -// TODO: activate gen_left_vec and gen_right_vec options -// TODO: couldn't template this function due to cusolverDnSgesvd and -// cusolverSnSgesvd. Check if there is any other way. 
template void svdQR(const raft::handle_t& handle, T* in, @@ -182,6 +181,219 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle, return detail::evaluateSVDByL2Norm(handle, A_d, U, S_vec, V, n_rows, n_cols, k, tol, stream); } +/** + * @defgroup svd Singular Value Decomposition + * @{ + */ + +/** + * @brief singular value decomposition (SVD) on a column major + * matrix using QR decomposition + * @param[in] handle raft::handle_t + * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] sing_vals singular values raft::device_vector_view of shape (K) + * @param[out] left_sing_vecs optional left singular values of raft::device_matrix_view with layout + * raft::col_major and dimensions (m, n) + * @param[out] right_sing_vecs optional right singular values of raft::device_matrix_view with + * layout raft::col_major and dimensions (n, n) + */ +template +void svd_qr( + const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + std::optional> left_sing_vecs = + std::nullopt, + std::optional> right_sing_vecs = + std::nullopt) +{ + if (left_sing_vecs) { + RAFT_EXPECTS(in.extent(0) == left_sing_vecs.value().extent(0) && + in.extent(1) == left_sing_vecs.value().extent(1), + "U should have dimensions m * n"); + } + if (right_sing_vecs) { + RAFT_EXPECTS(in.extent(1) == right_sing_vecs.value().extent(0) && + in.extent(1) == right_sing_vecs.value().extent(1), + "V should have dimensions n * n"); + } + svdQR(handle, + const_cast(in.data_handle()), + in.extent(0), + in.extent(1), + sing_vals.data_handle(), + left_sing_vecs.value().data_handle(), + right_sing_vecs.value().data_handle(), + false, + left_sing_vecs.has_value(), + right_sing_vecs.has_value(), + handle.get_stream()); +} + +/** + * @brief Overload of `svd_qr` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `svd_qr`. + */ +template +void svd_qr(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + svd_qr(handle, in, sing_vals, U_optional, V_optional); +} + +/** + * @brief singular value decomposition (SVD) on a column major + * matrix using QR decomposition. 
Right singular vector matrix is transposed before returning + * @param[in] handle raft::handle_t + * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] sing_vals singular values raft::device_vector_view of shape (K) + * @param[out] left_sing_vecs optional left singular values of raft::device_matrix_view with layout + * raft::col_major and dimensions (m, n) + * @param[out] right_sing_vecs optional right singular values of raft::device_matrix_view with + * layout raft::col_major and dimensions (n, n) + */ +template +void svd_qr_transpose_right_vec( + const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + std::optional> left_sing_vecs = + std::nullopt, + std::optional> right_sing_vecs = + std::nullopt) +{ + if (left_sing_vecs) { + RAFT_EXPECTS(in.extent(0) == left_sing_vecs.value().extent(0) && + in.extent(1) == left_sing_vecs.value().extent(1), + "U should have dimensions m * n"); + } + if (right_sing_vecs) { + RAFT_EXPECTS(in.extent(1) == right_sing_vecs.value().extent(0) && + in.extent(1) == right_sing_vecs.value().extent(1), + "V should have dimensions n * n"); + } + svdQR(handle, + const_cast(in.data_handle()), + in.extent(0), + in.extent(1), + sing_vals.data_handle(), + left_sing_vecs.value().data_handle(), + right_sing_vecs.value().data_handle(), + true, + left_sing_vecs.has_value(), + right_sing_vecs.has_value(), + handle.get_stream()); +} + +/** + * @brief Overload of `svd_qr_transpose_right_vec` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `svd_qr_transpose_right_vec`. + */ +template +void svd_qr_transpose_right_vec( + const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view sing_vals, + UType&& U, + VType&& V) +{ + std::optional> U_optional = + std::forward(U); + std::optional> V_optional = + std::forward(V); + + svd_qr_transpose_right_vec(handle, in, sing_vals, U_optional, V_optional); +} + +/** + * @brief singular value decomposition (SVD) on a column major + * matrix using Eigen decomposition. 
A square symmetric covariance matrix is constructed for the SVD + * @param[in] handle raft::handle_t + * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) + * @param[out] S singular values raft::device_vector_view of shape (K) + * @param[out] V right singular values of raft::device_matrix_view with layout + * raft::col_major and dimensions (m, n) + * @param[out] U optional left singular values of raft::device_matrix_view with layout + * raft::col_major and dimensions (m, n) + */ +template +void svd_eig( + const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view S, + raft::device_matrix_view V, + std::optional> U = std::nullopt) +{ + if (U) { + RAFT_EXPECTS(in.extent(0) == U.value().extent(0) && in.extent(1) == U.value().extent(1), + "U should have dimensions m * n"); + } + RAFT_EXPECTS(in.extent(0) == V.extent(0) && in.extent(1) == V.extent(1), + "V should have dimensions n * n"); + svdEig(handle, + const_cast(in.data_handle()), + in.extent(0), + in.extent(1), + S.data_handle(), + U.value().data_handle(), + V.value().data_handle(), + U.has_value(), + handle.get_stream()); +} + +/** + * @brief reconstruct a matrix use left and right singular vectors and + * singular values + * @param[in] handle raft::handle_t + * @param[in] U left singular values of raft::device_matrix_view with layout + * raft::col_major and dimensions (m, k) + * @param[in] S singular values raft::device_vector_view of shape (k, k) + * @param[in] V right singular values of raft::device_matrix_view with layout + * raft::col_major and dimensions (k, n) + * @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n) + */ +template +void svd_reconstruction(const raft::handle_t& handle, + raft::device_matrix_view U, + raft::device_vector_view S, + raft::device_matrix_view V, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(S.extent(0) == S.extent(1), "S should be a square matrix"); + RAFT_EXPECTS(S.extent(0) == U.extent(1), + "Number of rows of S should be equal to number of columns in U"); + RAFT_EXPECTS(S.extent(1) == V.extent(0), + "Number of columns of S should be equal to number of rows in V"); + RAFT_EXPECTS(out.extent(0) == U.extent(0) && out.extent(1) == V.extent(1), + "Number of rows should be equal in out and U and number of columns should be equal " + "in out and V"); + + svdReconstruction(handle, + const_cast(U.data_handle()), + const_cast(S.data_handle()), + const_cast(V.data_handle()), + out.extent(0), + out.extent(1), + S.extent(0), + handle.get_stream()); +} + +/** @} */ // end of group svd + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/ternary_op.cuh b/cpp/include/raft/linalg/ternary_op.cuh index 158cca168d..10e91a0313 100644 --- a/cpp/include/raft/linalg/ternary_op.cuh +++ b/cpp/include/raft/linalg/ternary_op.cuh @@ -19,7 +19,11 @@ #pragma once -#include +#include "detail/ternary_op.cuh" + +#include +#include +#include namespace raft { namespace linalg { @@ -37,8 +41,8 @@ namespace linalg { * @param op the device-lambda * @param stream cuda stream where to launch work */ -template -void ternaryOp(math_t* out, +template +void ternaryOp(out_t* out, const math_t* in1, const math_t* in2, const math_t* in3, @@ -49,6 +53,64 @@ void ternaryOp(math_t* out, detail::ternaryOp(out, in1, in2, in3, len, op, stream); } +/** + * @defgroup ternary_op Element-Wise Ternary Operation + * @{ + */ + +/** + * @brief perform element-wise ternary operation on the input arrays + * @tparam 
InType Input Type raft::device_mdspan + * @tparam Lambda the device-lambda performing the actual operation + * @tparam OutType Output Type raft::device_mdspan + * @param[in] handle raft::handle_t + * @param[in] in1 First input + * @param[in] in2 Second input + * @param[in] in3 Third input + * @param[out] out Output + * @param[in] op the device-lambda + * @note Lambda must be a functor with the following signature: + * `OutType func(const InType& val1, const InType& val2, const InType& val3);` + */ +template , + typename = raft::enable_if_output_device_mdspan> +void ternary_op( + const raft::handle_t& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in2), "Input 2 must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in3), "Input 3 must be contiguous"); + RAFT_EXPECTS(out.size() == in1.size() && in1.size() == in2.size() && in2.size() == in3.size(), + "Size mismatch between Output and Inputs"); + + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + if (out.size() <= std::numeric_limits::max()) { + ternaryOp(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + in3.data_handle(), + out.size(), + op, + handle.get_stream()); + } else { + ternaryOp(out.data_handle(), + in1.data_handle(), + in2.data_handle(), + in3.data_handle(), + out.size(), + op, + handle.get_stream()); + } +} + +/** @} */ // end of group ternary_op + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh index f2466df463..a90bda06d5 100644 --- a/cpp/include/raft/linalg/unary_op.cuh +++ b/cpp/include/raft/linalg/unary_op.cuh @@ -20,6 +20,10 @@ #include "detail/unary_op.cuh" +#include +#include +#include + namespace raft { namespace linalg { @@ -71,6 +75,75 @@ void writeOnlyUnaryOp(OutType* out, IdxType len, Lambda op, cudaStream_t stream) detail::writeOnlyUnaryOpCaller(out, len, op, stream); } +/** + * @defgroup unary_op Element-Wise Unary Operations + * @{ + */ + +/** + * @brief perform element-wise binary operation on the input arrays + * @tparam InType Input Type raft::device_mdspan + * @tparam Lambda the device-lambda performing the actual operation + * @tparam OutType Output Type raft::device_mdspan + * @param[in] handle raft::handle_t + * @param[in] in Input + * @param[out] out Output + * @param[in] op the device-lambda + * @note Lambda must be a functor with the following signature: + * `InType func(const InType& val);` + */ +template , + typename = raft::enable_if_output_device_mdspan> +void unary_op(const raft::handle_t& handle, InType in, OutType out, Lambda op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(out.size() == in.size(), "Size mismatch between Output and Input"); + + using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; + + if (out.size() <= std::numeric_limits::max()) { + unaryOp( + out.data_handle(), in.data_handle(), out.size(), op, handle.get_stream()); + } else { + unaryOp( + out.data_handle(), in.data_handle(), out.size(), op, handle.get_stream()); + } +} + +/** + * @brief perform element-wise binary operation on the input arrays + * This function 
does not read from the input + * @tparam InType Input Type raft::device_mdspan + * @tparam Lambda the device-lambda performing the actual operation + * @param[in] handle raft::handle_t + * @param[inout] in Input/Output + * @param[in] op the device-lambda + * @note Lambda must be a functor with the following signature: + * `InType func(const InType& val);` + */ +template > +void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + + using in_value_t = typename InType::value_type; + + if (in.size() <= std::numeric_limits::max()) { + writeOnlyUnaryOp( + in.data_handle(), in.size(), op, handle.get_stream()); + } else { + writeOnlyUnaryOp( + in.data_handle(), in.size(), op, handle.get_stream()); + } +} + +/** @} */ // end of group unary_op + }; // end namespace linalg }; // end namespace raft diff --git a/cpp/include/raft/stats/detail/kl_divergence.cuh b/cpp/include/raft/stats/detail/kl_divergence.cuh index d396d95206..6d9cc41790 100644 --- a/cpp/include/raft/stats/detail/kl_divergence.cuh +++ b/cpp/include/raft/stats/detail/kl_divergence.cuh @@ -67,7 +67,7 @@ DataT kl_divergence(const DataT* modelPDF, const DataT* candidatePDF, int size, rmm::device_scalar d_KLDVal(stream); RAFT_CUDA_TRY(cudaMemsetAsync(d_KLDVal.data(), 0, sizeof(DataT), stream)); - raft::linalg::mapThenSumReduce, 256, const DataT*>( + raft::linalg::mapThenSumReduce, size_t, 256, const DataT*>( d_KLDVal.data(), (size_t)size, KLDOp(), stream, modelPDF, candidatePDF); DataT h_KLDVal; diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp new file mode 100644 index 0000000000..b34843f5e8 --- /dev/null +++ b/cpp/include/raft/util/input_validation.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +namespace raft { + +template +constexpr bool is_row_or_column_major(mdspan m) +{ + return false; +} + +template +constexpr bool is_row_or_column_major(mdspan m) +{ + return true; +} + +template +constexpr bool is_row_or_column_major(mdspan m) +{ + return true; +} + +template +constexpr bool is_row_or_column_major(mdspan m) +{ + return m.is_exhaustive(); +} + +}; // end namespace raft \ No newline at end of file diff --git a/cpp/test/linalg/add.cu b/cpp/test/linalg/add.cu index d9a90321e1..c73791086b 100644 --- a/cpp/test/linalg/add.cu +++ b/cpp/test/linalg/add.cu @@ -46,7 +46,12 @@ class AddTest : public ::testing::TestWithParam> { uniform(handle, r, in1.data(), len, InT(-1.0), InT(1.0)); uniform(handle, r, in2.data(), len, InT(-1.0), InT(1.0)); naiveAddElem(out_ref.data(), in1.data(), in2.data(), len, stream); - add(out.data(), in1.data(), in2.data(), len, stream); + + auto out_view = raft::make_device_vector_view(out.data(), out.size()); + auto in1_view = raft::make_device_vector_view(in1.data(), in1.size()); + auto in2_view = raft::make_device_vector_view(in2.data(), in2.size()); + + add(handle, in1_view, in2_view, out_view); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/binary_op.cu b/cpp/test/linalg/binary_op.cu index 25383c5ca1..b92fa09427 100644 --- a/cpp/test/linalg/binary_op.cu +++ b/cpp/test/linalg/binary_op.cu @@ -30,10 +30,14 @@ namespace linalg { // within its class template void binaryOpLaunch( - OutType* out, const InType* in1, const InType* in2, IdxType len, cudaStream_t stream) + const raft::handle_t& handle, OutType* out, const InType* in1, const InType* in2, IdxType len) { - binaryOp( - out, in1, in2, len, [] __device__(InType a, InType b) { return a + b; }, stream); + auto out_view = raft::make_device_vector_view(out, len); + auto in1_view = raft::make_device_vector_view(in1, len); + auto in2_view = raft::make_device_vector_view(in2, len); + + binary_op( + handle, in1_view, in2_view, out_view, [] __device__(InType a, InType b) { return a + b; }); } template @@ -57,7 +61,7 @@ class BinaryOpTest : public ::testing::TestWithParamtestR1Update(); } TYPED_TEST(CholeskyR1Test, throwError) { this->testR1Error(); } }; // namespace linalg -}; // namespace raft +}; // namespace raft \ No newline at end of file diff --git a/cpp/test/linalg/coalesced_reduction.cu b/cpp/test/linalg/coalesced_reduction.cu index 8e28b35cef..cc2acef565 100644 --- a/cpp/test/linalg/coalesced_reduction.cu +++ b/cpp/test/linalg/coalesced_reduction.cu @@ -43,10 +43,12 @@ template // within its class template void coalescedReductionLaunch( - T* dots, const T* data, int cols, int rows, cudaStream_t stream, bool inplace = false) + const raft::handle_t& handle, T* dots, const T* data, int cols, int rows, bool inplace = false) { - coalescedReduction( - dots, data, cols, rows, (T)0, stream, inplace, [] __device__(T in, int i) { return in * in; }); + auto dots_view = raft::make_device_vector_view(dots, rows); + auto data_view = raft::make_device_matrix_view(data, rows, cols); + coalesced_reduction( + handle, data_view, dots_view, (T)0, inplace, [] __device__(T in, int i) { return in * in; }); } template @@ -71,9 +73,9 @@ class coalescedReductionTest : public ::testing::TestWithParam(in.data(), len); + auto scalar_view = raft::make_host_scalar_view(¶ms.scalar); + divide_scalar(handle, in_view, out_view, scalar_view); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/eig.cu b/cpp/test/linalg/eig.cu index e05cb8f5fd..a913b14fcb 100644 --- a/cpp/test/linalg/eig.cu +++ 
b/cpp/test/linalg/eig.cu @@ -93,47 +93,50 @@ class EigTest : public ::testing::TestWithParam> { raft::update_device(eig_vectors_ref.data(), eig_vectors_ref_h, len, stream); raft::update_device(eig_vals_ref.data(), eig_vals_ref_h, params.n_col, stream); - eigDC(handle, - cov_matrix.data(), - params.n_row, - params.n_col, - eig_vectors.data(), - eig_vals.data(), - stream); + auto cov_matrix_view = raft::make_device_matrix_view( + cov_matrix.data(), params.n_row, params.n_col); + auto eig_vectors_view = raft::make_device_matrix_view( + eig_vectors.data(), params.n_row, params.n_col); + auto eig_vals_view = + raft::make_device_vector_view(eig_vals.data(), params.n_row); + + auto eig_vectors_jacobi_view = raft::make_device_matrix_view( + eig_vectors_jacobi.data(), params.n_row, params.n_col); + auto eig_vals_jacobi_view = + raft::make_device_vector_view(eig_vals_jacobi.data(), params.n_row); + + eig_dc(handle, cov_matrix_view, eig_vectors_view, eig_vals_view); T tol = 1.e-7; int sweeps = 15; - eigJacobi(handle, - cov_matrix.data(), - params.n_row, - params.n_col, - eig_vectors_jacobi.data(), - eig_vals_jacobi.data(), - stream, - tol, - sweeps); + eig_jacobi(handle, cov_matrix_view, eig_vectors_jacobi_view, eig_vals_jacobi_view, tol, sweeps); // test code for comparing two methods len = params.n * params.n; uniform(handle, r, cov_matrix_large.data(), len, T(-1.0), T(1.0)); - eigDC(handle, - cov_matrix_large.data(), - params.n, - params.n, - eig_vectors_large.data(), - eig_vals_large.data(), - stream); - eigJacobi(handle, - cov_matrix_large.data(), - params.n, - params.n, - eig_vectors_jacobi_large.data(), - eig_vals_jacobi_large.data(), - stream, - tol, - sweeps); + auto cov_matrix_large_view = + raft::make_device_matrix_view( + cov_matrix_large.data(), params.n, params.n); + auto eig_vectors_large_view = raft::make_device_matrix_view( + eig_vectors_large.data(), params.n, params.n); + auto eig_vals_large_view = + raft::make_device_vector_view(eig_vals_large.data(), params.n); + + auto eig_vectors_jacobi_large_view = + raft::make_device_matrix_view( + eig_vectors_jacobi_large.data(), params.n, params.n); + auto eig_vals_jacobi_large_view = + raft::make_device_vector_view(eig_vals_jacobi_large.data(), params.n); + + eig_dc(handle, cov_matrix_large_view, eig_vectors_large_view, eig_vals_large_view); + eig_jacobi(handle, + cov_matrix_large_view, + eig_vectors_jacobi_large_view, + eig_vals_jacobi_large_view, + tol, + sweeps); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/eig_sel.cu b/cpp/test/linalg/eig_sel.cu index cc1dd589d0..9d57c4fa0a 100644 --- a/cpp/test/linalg/eig_sel.cu +++ b/cpp/test/linalg/eig_sel.cu @@ -80,18 +80,22 @@ class EigSelTest : public ::testing::TestWithParam> { raft::update_device( eig_vectors_ref.data(), eig_vectors_ref_h, params.n_eigen_vals * params.n, stream); - raft::update_device(eig_vals_ref.data(), eig_vals_ref_h, params.n, stream); - - raft::linalg::eigSelDC(handle, - cov_matrix.data(), - params.n, - params.n, - params.n_eigen_vals, - eig_vectors.data(), - eig_vals.data(), - EigVecMemUsage::OVERWRITE_INPUT, - stream); - handle.sync_stream(stream); + raft::update_device(eig_vals_ref.data(), eig_vals_ref_h, params.n_eigen_vals, stream); + + auto cov_matrix_view = raft::make_device_matrix_view( + cov_matrix.data(), params.n, params.n); + auto eig_vectors_view = raft::make_device_matrix_view( + eig_vectors.data(), params.n_eigen_vals, params.n); + auto eig_vals_view = + raft::make_device_vector_view(eig_vals.data(), params.n_eigen_vals); + + 
raft::linalg::eig_dc_selective(handle, + cov_matrix_view, + eig_vectors_view, + eig_vals_view, + static_cast(params.n_eigen_vals), + EigVecMemUsage::OVERWRITE_INPUT); + handle.sync_stream(); } protected: diff --git a/cpp/test/linalg/gemm_layout.cu b/cpp/test/linalg/gemm_layout.cu index 4b05004ccf..dbe10ab4cc 100644 --- a/cpp/test/linalg/gemm_layout.cu +++ b/cpp/test/linalg/gemm_layout.cu @@ -94,17 +94,35 @@ class GemmLayoutTest : public ::testing::TestWithParam> { naiveGemm<<>>( refZ, X, Y, params.M, params.N, params.K, params.zLayout, params.xLayout, params.yLayout); - gemm(handle, - Z, - X, - Y, - params.M, - params.N, - params.K, - params.zLayout, - params.xLayout, - params.yLayout, - stream); + auto x_view_row_major = raft::make_device_matrix_view(X, params.M, params.K); + auto y_view_row_major = raft::make_device_matrix_view(Y, params.K, params.N); + auto z_view_row_major = raft::make_device_matrix_view(Z, params.M, params.N); + + auto x_view_col_major = + raft::make_device_matrix_view(X, params.M, params.K); + auto y_view_col_major = + raft::make_device_matrix_view(Y, params.K, params.N); + auto z_view_col_major = + raft::make_device_matrix_view(Z, params.M, params.N); + + if (params.xLayout && params.yLayout && params.zLayout) { + gemm(handle, x_view_col_major, y_view_col_major, z_view_col_major); + } else if (params.xLayout && params.yLayout && !params.zLayout) { + gemm(handle, x_view_col_major, y_view_col_major, z_view_row_major); + } else if (params.xLayout && !params.yLayout && params.zLayout) { + gemm(handle, x_view_col_major, y_view_row_major, z_view_col_major); + } else if (!params.xLayout && params.yLayout && params.zLayout) { + gemm(handle, x_view_row_major, y_view_col_major, z_view_col_major); + } else if (params.xLayout && !params.yLayout && !params.zLayout) { + gemm(handle, x_view_col_major, y_view_row_major, z_view_row_major); + } else if (!params.xLayout && params.yLayout && !params.zLayout) { + gemm(handle, x_view_row_major, y_view_col_major, z_view_row_major); + } else if (!params.xLayout && !params.yLayout && params.zLayout) { + gemm(handle, x_view_row_major, y_view_row_major, z_view_col_major); + } else if (!params.xLayout && !params.yLayout && !params.zLayout) { + gemm(handle, x_view_row_major, y_view_row_major, z_view_row_major); + } + handle.sync_stream(); } diff --git a/cpp/test/linalg/gemv.cu b/cpp/test/linalg/gemv.cu index f4c437bdfc..2bd9abc200 100644 --- a/cpp/test/linalg/gemv.cu +++ b/cpp/test/linalg/gemv.cu @@ -109,15 +109,20 @@ class GemvTest : public ::testing::TestWithParam> { naiveGemv<<>>( refy.data(), A.data(), x.data(), params.n_rows, params.n_cols, params.lda, params.trans_a); - gemv(handle, - A.data(), - params.n_rows, - params.n_cols, - params.lda, - x.data(), - y.data(), - params.trans_a, - stream); + auto A_row_major = + raft::make_device_matrix_view(A.data(), params.n_rows, params.n_cols); + auto A_col_major = raft::make_device_matrix_view( + A.data(), params.n_rows, params.n_cols); + + auto x_view = raft::make_device_vector_view(x.data(), xElems); + auto y_view = raft::make_device_vector_view(y.data(), yElems); + + if (params.trans_a) { + gemv(handle, A_row_major, x_view, y_view); + } else { + gemv(handle, A_col_major, x_view, y_view); + } + handle.sync_stream(); } @@ -127,20 +132,14 @@ class GemvTest : public ::testing::TestWithParam> { const std::vector> inputsf = {{80, 70, 80, true, 76433ULL}, {80, 100, 80, true, 426646ULL}, {20, 100, 20, true, 37703ULL}, - {100, 60, 200, true, 538004ULL}, - {50, 10, 60, false, 73012ULL}, {90, 90, 90, 
false, 538147ULL}, - {30, 100, 30, false, 412352ULL}, - {40, 80, 100, false, 297941ULL}}; + {30, 100, 30, false, 412352ULL}}; const std::vector> inputsd = {{10, 70, 10, true, 535648ULL}, {30, 30, 30, true, 956681ULL}, {70, 80, 70, true, 875083ULL}, - {80, 90, 200, true, 50744ULL}, {90, 90, 90, false, 506321ULL}, - {40, 100, 70, false, 638418ULL}, - {80, 50, 80, false, 701529ULL}, - {50, 80, 60, false, 893038ULL}}; + {80, 50, 80, false, 701529ULL}}; typedef GemvTest GemvTestF; TEST_P(GemvTestF, Result) diff --git a/cpp/test/linalg/map.cu b/cpp/test/linalg/map.cu index 6fa26456e3..95a2aff130 100644 --- a/cpp/test/linalg/map.cu +++ b/cpp/test/linalg/map.cu @@ -33,12 +33,14 @@ void mapLaunch(OutType* out, IdxType len, cudaStream_t stream) { + raft::handle_t handle{stream}; + auto out_view = raft::make_device_vector_view(out, len); + auto in1_view = raft::make_device_vector_view(in1, len); map( - out, - len, + handle, + in1_view, + out_view, [=] __device__(InType a, InType b, InType c) { return a + b + c + scalar; }, - stream, - in1, in2, in3); } diff --git a/cpp/test/linalg/map_then_reduce.cu b/cpp/test/linalg/map_then_reduce.cu index 170962006f..adf784f601 100644 --- a/cpp/test/linalg/map_then_reduce.cu +++ b/cpp/test/linalg/map_then_reduce.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include #include +#include #include #include #include @@ -149,19 +150,23 @@ class MapGenericReduceTest : public ::testing::Test { void testMin() { - auto op = [] __device__(InType in) { return in; }; - const OutType neutral = std::numeric_limits::max(); - mapThenReduce( - output.data(), input.size(), neutral, op, cub::Min(), handle.get_stream(), input.data()); + auto op = [] __device__(InType in) { return in; }; + OutType neutral = std::numeric_limits::max(); + auto output_view = raft::make_device_scalar_view(output.data()); + auto input_view = raft::make_device_vector_view( + input.data(), static_cast(input.size())); + map_reduce(handle, input_view, output_view, neutral, op, cub::Min()); EXPECT_TRUE(raft::devArrMatch( OutType(1), output.data(), 1, raft::Compare(), handle.get_stream())); } void testMax() { - auto op = [] __device__(InType in) { return in; }; - const OutType neutral = std::numeric_limits::min(); - mapThenReduce( - output.data(), input.size(), neutral, op, cub::Max(), handle.get_stream(), input.data()); + auto op = [] __device__(InType in) { return in; }; + OutType neutral = std::numeric_limits::min(); + auto output_view = raft::make_device_scalar_view(output.data()); + auto input_view = raft::make_device_vector_view( + input.data(), static_cast(input.size())); + map_reduce(handle, input_view, output_view, neutral, op, cub::Max()); EXPECT_TRUE(raft::devArrMatch( OutType(5), output.data(), 1, raft::Compare(), handle.get_stream())); } diff --git a/cpp/test/linalg/matrix_vector_op.cu b/cpp/test/linalg/matrix_vector_op.cu index 74ba250f86..2023ce4121 100644 --- a/cpp/test/linalg/matrix_vector_op.cu +++ b/cpp/test/linalg/matrix_vector_op.cu @@ -41,7 +41,8 @@ template // for an extended __device__ lambda cannot have private or protected access // within its class template -void matrixVectorOpLaunch(T* out, +void matrixVectorOpLaunch(const raft::handle_t& handle, + T* out, const T* in, const T* vec1, const T* vec2, @@ -49,32 +50,50 @@ void matrixVectorOpLaunch(T* out, IdxType N, bool rowMajor, bool bcastAlongRows, - bool useTwoVectors, - cudaStream_t stream) + bool useTwoVectors) { + auto out_row_major = raft::make_device_matrix_view(out, N, D); + auto in_row_major = raft::make_device_matrix_view(in, 
N, D); + + auto out_col_major = raft::make_device_matrix_view(out, N, D); + auto in_col_major = raft::make_device_matrix_view(in, N, D); + + auto apply = bcastAlongRows ? Apply::ALONG_ROWS : Apply::ALONG_COLUMNS; + auto len = bcastAlongRows ? D : N; + auto vec1_view = raft::make_device_vector_view(vec1, len); + auto vec2_view = raft::make_device_vector_view(vec2, len); + if (useTwoVectors) { - matrixVectorOp( - out, - in, - vec1, - vec2, - D, - N, - rowMajor, - bcastAlongRows, - [] __device__(T a, T b, T c) { return a + b + c; }, - stream); + if (rowMajor) { + matrix_vector_op(handle, + in_row_major, + vec1_view, + vec2_view, + out_row_major, + apply, + [] __device__(T a, T b, T c) { return a + b + c; }); + } else { + matrix_vector_op(handle, + in_col_major, + vec1_view, + vec2_view, + out_col_major, + + apply, + [] __device__(T a, T b, T c) { return a + b + c; }); + } } else { - matrixVectorOp( - out, - in, - vec1, - D, - N, - rowMajor, - bcastAlongRows, - [] __device__(T a, T b) { return a + b; }, - stream); + if (rowMajor) { + matrix_vector_op( + handle, in_row_major, vec1_view, out_row_major, apply, [] __device__(T a, T b) { + return a + b; + }); + } else { + matrix_vector_op( + handle, in_col_major, vec1_view, out_col_major, apply, [] __device__(T a, T b) { + return a + b; + }); + } } } @@ -124,7 +143,8 @@ class MatVecOpTest : public ::testing::TestWithParam> (T)1.0, stream); } - matrixVectorOpLaunch(out.data(), + matrixVectorOpLaunch(handle, + out.data(), in.data(), vec1.data(), vec2.data(), @@ -132,9 +152,8 @@ class MatVecOpTest : public ::testing::TestWithParam> N, params.rowMajor, params.bcastAlongRows, - params.useTwoVectors, - stream); - handle.sync_stream(stream); + params.useTwoVectors); + handle.sync_stream(); } protected: diff --git a/cpp/test/linalg/multiply.cu b/cpp/test/linalg/multiply.cu index 852b869676..1d6446c5c0 100644 --- a/cpp/test/linalg/multiply.cu +++ b/cpp/test/linalg/multiply.cu @@ -44,7 +44,10 @@ class MultiplyTest : public ::testing::TestWithParam> { int len = params.len; uniform(handle, r, in.data(), len, T(-1.0), T(1.0)); naiveScale(out_ref.data(), in.data(), params.scalar, len, stream); - multiplyScalar(out.data(), in.data(), params.scalar, len, stream); + auto out_view = raft::make_device_vector_view(out.data(), len); + auto in_view = raft::make_device_vector_view(in.data(), len); + auto scalar_view = raft::make_host_scalar_view(¶ms.scalar); + multiply_scalar(handle, in_view, out_view, scalar_view); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index a07e5a8a7a..5243f2435f 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -88,12 +88,24 @@ class RowNormTest : public ::testing::TestWithParam> { int rows = params.rows, cols = params.cols, len = rows * cols; uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveRowNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); + auto output_view = raft::make_device_vector_view(dots_act.data(), params.rows); + auto input_row_major = raft::make_device_matrix_view( + data.data(), params.rows, params.cols); + auto input_col_major = raft::make_device_matrix_view( + data.data(), params.rows, params.cols); if (params.do_sqrt) { - auto fin_op = [] __device__(T in) { return raft::mySqrt(in); }; - rowNorm( - dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream, fin_op); + auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; + if (params.rowMajor) { + norm(handle, input_row_major, 
output_view, params.type, Apply::ALONG_ROWS, fin_op); + } else { + norm(handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS, fin_op); + } } else { - rowNorm(dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream); + if (params.rowMajor) { + norm(handle, input_row_major, output_view, params.type, Apply::ALONG_ROWS); + } else { + norm(handle, input_col_major, output_view, params.type, Apply::ALONG_ROWS); + } } handle.sync_stream(stream); } @@ -152,12 +164,24 @@ class ColNormTest : public ::testing::TestWithParam> { uniform(handle, r, data.data(), len, T(-1.0), T(1.0)); naiveColNorm(dots_exp.data(), data.data(), cols, rows, params.type, params.do_sqrt, stream); + auto output_view = raft::make_device_vector_view(dots_act.data(), params.cols); + auto input_row_major = raft::make_device_matrix_view( + data.data(), params.rows, params.cols); + auto input_col_major = raft::make_device_matrix_view( + data.data(), params.rows, params.cols); if (params.do_sqrt) { - auto fin_op = [] __device__(T in) { return raft::mySqrt(in); }; - colNorm( - dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream, fin_op); + auto fin_op = [] __device__(const T in) { return raft::mySqrt(in); }; + if (params.rowMajor) { + norm(handle, input_row_major, output_view, params.type, Apply::ALONG_COLUMNS, fin_op); + } else { + norm(handle, input_col_major, output_view, params.type, Apply::ALONG_COLUMNS, fin_op); + } } else { - colNorm(dots_act.data(), data.data(), cols, rows, params.type, params.rowMajor, stream); + if (params.rowMajor) { + norm(handle, input_row_major, output_view, params.type, Apply::ALONG_COLUMNS); + } else { + norm(handle, input_col_major, output_view, params.type, Apply::ALONG_COLUMNS); + } } handle.sync_stream(stream); } diff --git a/cpp/test/linalg/power.cu b/cpp/test/linalg/power.cu index e66aa4b4ae..bdab49d5c8 100644 --- a/cpp/test/linalg/power.cu +++ b/cpp/test/linalg/power.cu @@ -97,10 +97,17 @@ class PowerTest : public ::testing::TestWithParam> { naivePowerElem(out_ref.data(), in1.data(), in2.data(), len, stream); naivePowerScalar(out_ref.data(), out_ref.data(), T(2), len, stream); - power(out.data(), in1.data(), in2.data(), len, stream); - powerScalar(out.data(), out.data(), T(2), len, stream); - power(in1.data(), in1.data(), in2.data(), len, stream); - powerScalar(in1.data(), in1.data(), T(2), len, stream); + auto out_view = raft::make_device_vector_view(out.data(), len); + auto in1_view = raft::make_device_vector_view(in1.data(), len); + auto const_out_view = raft::make_device_vector_view(out.data(), len); + auto const_in1_view = raft::make_device_vector_view(in1.data(), len); + auto const_in2_view = raft::make_device_vector_view(in2.data(), len); + const auto scalar = static_cast(2); + auto scalar_view = raft::make_host_scalar_view(&scalar); + power(handle, const_in1_view, const_in2_view, out_view); + power_scalar(handle, const_out_view, out_view, scalar_view); + power(handle, const_in1_view, const_in2_view, in1_view); + power_scalar(handle, const_in1_view, in1_view, scalar_view); handle.sync_stream(); } diff --git a/cpp/test/linalg/reduce.cu b/cpp/test/linalg/reduce.cu index 674cb24069..57654f88ab 100644 --- a/cpp/test/linalg/reduce.cu +++ b/cpp/test/linalg/reduce.cu @@ -52,16 +52,37 @@ void reduceLaunch(OutType* dots, bool inplace, cudaStream_t stream) { - reduce(dots, - data, - cols, - rows, - (OutType)0, - rowMajor, - alongRows, - stream, - inplace, - [] __device__(InType in, int i) { return static_cast(in * in); }); + Apply 
apply = alongRows ? Apply::ALONG_ROWS : Apply::ALONG_COLUMNS; + int output_size = alongRows ? cols : rows; + + auto output_view_row_major = raft::make_device_vector_view(dots, output_size); + auto input_view_row_major = raft::make_device_matrix_view(data, rows, cols); + + auto output_view_col_major = raft::make_device_vector_view(dots, output_size); + auto input_view_col_major = + raft::make_device_matrix_view(data, rows, cols); + + raft::handle_t handle{stream}; + + if (rowMajor) { + reduce(handle, + input_view_row_major, + output_view_row_major, + (OutType)0, + + apply, + inplace, + [] __device__(InType in, int i) { return static_cast(in * in); }); + } else { + reduce(handle, + input_view_col_major, + output_view_col_major, + (OutType)0, + + apply, + inplace, + [] __device__(InType in, int i) { return static_cast(in * in); }); + } } template diff --git a/cpp/test/linalg/reduce_cols_by_key.cu b/cpp/test/linalg/reduce_cols_by_key.cu index 5d4ea359a3..63afbe2fed 100644 --- a/cpp/test/linalg/reduce_cols_by_key.cu +++ b/cpp/test/linalg/reduce_cols_by_key.cu @@ -84,7 +84,10 @@ class ReduceColsTest : public ::testing::TestWithParam> { uniform(handle, r, in.data(), nrows * ncols, T(-1.0), T(1.0)); uniformInt(handle, r, keys.data(), ncols, 0u, params.nkeys); naiveReduceColsByKey(in.data(), keys.data(), out_ref.data(), nrows, ncols, nkeys, stream); - reduce_cols_by_key(in.data(), keys.data(), out.data(), nrows, ncols, nkeys, stream); + auto input_view = raft::make_device_matrix_view(in.data(), nrows, ncols); + auto output_view = raft::make_device_matrix_view(out.data(), nrows, nkeys); + auto keys_view = raft::make_device_vector_view(keys.data(), ncols); + reduce_cols_by_key(handle, input_view, keys_view, output_view, nkeys); raft::interruptible::synchronize(stream); } diff --git a/cpp/test/linalg/reduce_rows_by_key.cu b/cpp/test/linalg/reduce_rows_by_key.cu index e8baeb5887..e575f37dd6 100644 --- a/cpp/test/linalg/reduce_rows_by_key.cu +++ b/cpp/test/linalg/reduce_rows_by_key.cu @@ -126,21 +126,20 @@ class ReduceRowTest : public ::testing::TestWithParam> { nkeys, out_ref.data(), stream); + auto input_view = raft::make_device_matrix_view( + in.data(), params.cols, static_cast(params.nobs)); + auto output_view = raft::make_device_matrix_view(out.data(), params.cols, params.nkeys); + auto keys_view = raft::make_device_vector_view( + keys.data(), static_cast(params.nobs)); + auto scratch_buf_view = + raft::make_device_vector_view(scratch_buf.data(), static_cast(params.nobs)); + std::optional> weights_view; if (params.weighted) { - reduce_rows_by_key(in.data(), - cols, - keys.data(), - params.weighted ? 
weight.data() : nullptr, - scratch_buf.data(), - nobs, - cols, - nkeys, - out.data(), - stream); - } else { - reduce_rows_by_key( - in.data(), cols, keys.data(), scratch_buf.data(), nobs, cols, nkeys, out.data(), stream); + weights_view.emplace(weight.data(), static_cast(params.nobs)); } + + reduce_rows_by_key( + handle, input_view, keys_view, output_view, params.nkeys, scratch_buf_view, weights_view); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/rsvd.cu b/cpp/test/linalg/rsvd.cu index 01736615eb..f774d59631 100644 --- a/cpp/test/linalg/rsvd.cu +++ b/cpp/test/linalg/rsvd.cu @@ -124,41 +124,29 @@ class RsvdTest : public ::testing::TestWithParam> { RAFT_CUDA_TRY(cudaMemsetAsync(S.data(), 0, S.size() * sizeof(T), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(V.data(), 0, V.size() * sizeof(T), stream)); + auto A_view = raft::make_device_matrix_view(A.data(), m, n); + std::optional> U_view = + raft::make_device_matrix_view(U.data(), m, params.k); + std::optional> V_view = + raft::make_device_matrix_view(V.data(), params.k, n); + auto S_vec_view = raft::make_device_vector_view(S.data(), params.k); + // RSVD tests if (params.k == 0) { // Test with PC and upsampling ratio - rsvdPerc(handle, - A.data(), - m, - n, - S.data(), - U.data(), - V.data(), - params.PC_perc, - params.UpS_perc, - params.use_bbt, - true, - true, - false, - eig_svd_tol, - max_sweeps, - stream); + if (params.use_bbt) { + rsvd_perc_symmetric( + handle, A_view, S_vec_view, params.PC_perc, params.UpS_perc, U_view, V_view); + } else { + rsvd_perc(handle, A_view, S_vec_view, params.PC_perc, params.UpS_perc, U_view, V_view); + } } else { // Test with directly given fixed rank - rsvdFixedRank(handle, - A.data(), - m, - n, - S.data(), - U.data(), - V.data(), - params.k, - params.p, - params.use_bbt, - true, - true, - true, - eig_svd_tol, - max_sweeps, - stream); + if (params.use_bbt) { + rsvd_fixed_rank_symmetric_jacobi( + handle, A_view, S_vec_view, params.p, eig_svd_tol, max_sweeps, U_view, V_view); + } else { + rsvd_fixed_rank_jacobi( + handle, A_view, S_vec_view, params.p, eig_svd_tol, max_sweeps, U_view, V_view); + } } raft::update_device(A.data(), A_backup_cpu.data(), m * n, stream); } diff --git a/cpp/test/linalg/sqrt.cu b/cpp/test/linalg/sqrt.cu index bb78d9f754..ed57e94914 100644 --- a/cpp/test/linalg/sqrt.cu +++ b/cpp/test/linalg/sqrt.cu @@ -72,9 +72,12 @@ class SqrtTest : public ::testing::TestWithParam> { uniform(handle, r, in1.data(), len, T(1.0), T(2.0)); naiveSqrtElem(out_ref.data(), in1.data(), len); + auto out_view = raft::make_device_vector_view(out.data(), len); + auto in_view = raft::make_device_vector_view(in1.data(), len); + auto in2_view = raft::make_device_vector_view(in1.data(), len); - sqrt(out.data(), in1.data(), len, stream); - sqrt(in1.data(), in1.data(), len, stream); + sqrt(handle, in_view, out_view); + sqrt(handle, in_view, in2_view); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } diff --git a/cpp/test/linalg/strided_reduction.cu b/cpp/test/linalg/strided_reduction.cu index c4f02310a5..39e2764def 100644 --- a/cpp/test/linalg/strided_reduction.cu +++ b/cpp/test/linalg/strided_reduction.cu @@ -34,8 +34,11 @@ struct stridedReductionInputs { template void stridedReductionLaunch(T* dots, const T* data, int cols, int rows, cudaStream_t stream) { - stridedReduction( - dots, data, cols, rows, (T)0, stream, false, [] __device__(T in, int i) { return in * in; }); + raft::handle_t handle{stream}; + auto dots_view = raft::make_device_vector_view(dots, cols); + auto data_view = 
raft::make_device_matrix_view(data, rows, cols); + strided_reduction( + handle, data_view, dots_view, (T)0, false, [] __device__(T in, int i) { return in * in; }); } template diff --git a/cpp/test/linalg/subtract.cu b/cpp/test/linalg/subtract.cu index 455f5e6c30..3904f9f33f 100644 --- a/cpp/test/linalg/subtract.cu +++ b/cpp/test/linalg/subtract.cu @@ -92,10 +92,18 @@ class SubtractTest : public ::testing::TestWithParam> { naiveSubtractElem(out_ref.data(), in1.data(), in2.data(), len, stream); naiveSubtractScalar(out_ref.data(), out_ref.data(), T(1), len, stream); - subtract(out.data(), in1.data(), in2.data(), len, stream); - subtractScalar(out.data(), out.data(), T(1), len, stream); - subtract(in1.data(), in1.data(), in2.data(), len, stream); - subtractScalar(in1.data(), in1.data(), T(1), len, stream); + auto out_view = raft::make_device_vector_view(out.data(), len); + auto in1_view = raft::make_device_vector_view(in1.data(), len); + auto const_out_view = raft::make_device_vector_view(out.data(), len); + auto const_in1_view = raft::make_device_vector_view(in1.data(), len); + auto const_in2_view = raft::make_device_vector_view(in2.data(), len); + const auto scalar = static_cast(1); + auto scalar_view = raft::make_host_scalar_view(&scalar); + + subtract(handle, const_in1_view, const_in2_view, out_view); + subtract_scalar(handle, const_out_view, out_view, scalar_view); + subtract(handle, const_in1_view, const_in2_view, in1_view); + subtract_scalar(handle, const_in1_view, in1_view, scalar_view); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/svd.cu b/cpp/test/linalg/svd.cu index 292793478c..c18417dc9e 100644 --- a/cpp/test/linalg/svd.cu +++ b/cpp/test/linalg/svd.cu @@ -78,17 +78,22 @@ class SvdTest : public ::testing::TestWithParam> { raft::update_device(right_eig_vectors_ref.data(), right_eig_vectors_ref_h, right_evl, stream); raft::update_device(sing_vals_ref.data(), sing_vals_ref_h, params.n_col, stream); - svdQR(handle, - data.data(), - params.n_row, - params.n_col, - sing_vals_qr.data(), - left_eig_vectors_qr.data(), - right_eig_vectors_trans_qr.data(), - true, - true, - true, - stream); + auto data_view = raft::make_device_matrix_view( + data.data(), params.n_row, params.n_col); + auto sing_vals_qr_view = + raft::make_device_vector_view(sing_vals_qr.data(), params.n_col); + std::optional> left_eig_vectors_qr_view = + raft::make_device_matrix_view( + left_eig_vectors_qr.data(), params.n_row, params.n_col); + std::optional> + right_eig_vectors_trans_qr_view = raft::make_device_matrix_view( + right_eig_vectors_trans_qr.data(), params.n_col, params.n_col); + + svd_qr_transpose_right_vec(handle, + data_view, + sing_vals_qr_view, + left_eig_vectors_qr_view, + right_eig_vectors_trans_qr_view); handle.sync_stream(stream); } diff --git a/cpp/test/linalg/ternary_op.cu b/cpp/test/linalg/ternary_op.cu index 21573eff48..e172d771cd 100644 --- a/cpp/test/linalg/ternary_op.cu +++ b/cpp/test/linalg/ternary_op.cu @@ -63,10 +63,16 @@ class ternaryOpTest : public ::testing::TestWithParam> { fill(handle, rng, in2.data(), len, T(2.0)); fill(handle, rng, in3.data(), len, T(3.0)); - auto add = [] __device__(T a, T b, T c) { return a + b + c; }; - auto mul = [] __device__(T a, T b, T c) { return a * b * c; }; - ternaryOp(out_add.data(), in1.data(), in2.data(), in3.data(), len, add, stream); - ternaryOp(out_mul.data(), in1.data(), in2.data(), in3.data(), len, mul, stream); + auto add = [] __device__(T a, T b, T c) { return a + b + c; }; + auto mul = [] __device__(T a, T b, T c) { return a * b * c; }; 
+ auto out_add_view = raft::make_device_vector_view(out_add.data(), len); + auto out_mul_view = raft::make_device_vector_view(out_mul.data(), len); + auto in1_view = raft::make_device_vector_view(in1.data(), len); + auto in2_view = raft::make_device_vector_view(in2.data(), len); + auto in3_view = raft::make_device_vector_view(in3.data(), len); + + ternary_op(handle, in1_view, in2_view, in3_view, out_add_view, add); + ternary_op(handle, in1_view, in2_view, in3_view, out_mul_view, mul); } protected: diff --git a/cpp/test/linalg/unary_op.cu b/cpp/test/linalg/unary_op.cu index 4174056170..57b009a0ac 100644 --- a/cpp/test/linalg/unary_op.cu +++ b/cpp/test/linalg/unary_op.cu @@ -30,14 +30,18 @@ namespace linalg { template void unaryOpLaunch(OutType* out, const InType* in, InType scalar, IdxType len, cudaStream_t stream) { + raft::handle_t handle{stream}; + auto out_view = raft::make_device_vector_view(out, len); + auto in_view = raft::make_device_vector_view(in, len); if (in == nullptr) { auto op = [scalar] __device__(OutType * ptr, IdxType idx) { *ptr = static_cast(scalar * idx); }; - writeOnlyUnaryOp(out, len, op, stream); + + write_only_unary_op(handle, out_view, op); } else { auto op = [scalar] __device__(InType in) { return static_cast(in * scalar); }; - unaryOp(out, in, len, op, stream); + unary_op(handle, in_view, out_view, op); } } diff --git a/cpp/test/mdspan_utils.cu b/cpp/test/mdspan_utils.cu index 5683c0267a..7f1efb78bb 100644 --- a/cpp/test/mdspan_utils.cu +++ b/cpp/test/mdspan_utils.cu @@ -55,17 +55,16 @@ void test_template_asserts() static_assert(is_mdspan_v, "Derived device mdspan type is not mdspan"); // Checking if types are device_mdspan - static_assert(is_device_accessible_mdspan_v>, + static_assert(is_device_mdspan_v>, "device_matrix_view type not a device_mdspan"); - static_assert(!is_device_accessible_mdspan_v>, + static_assert(!is_device_mdspan_v>, "host_matrix_view type is a device_mdspan"); - static_assert(is_device_accessible_mdspan_v, - "Derived device mdspan type is not device_mdspan"); + static_assert(is_device_mdspan_v, "Derived device mdspan type is not device_mdspan"); // Checking if types are host_mdspan - static_assert(!is_host_accessible_mdspan_v>, + static_assert(!is_host_mdspan_v>, "device_matrix_view type is a host_mdspan"); - static_assert(is_host_accessible_mdspan_v>, + static_assert(is_host_mdspan_v>, "host_matrix_view type is not a host_mdspan"); // checking variadics