diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index 1c45572782..a5115407dd 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -16,6 +16,7 @@ #include #include +#include #include #if defined RAFT_COMPILED #include diff --git a/cpp/include/raft/core/detail/nvtx.hpp b/cpp/include/raft/core/detail/nvtx.hpp index adbf3a3666..e0f985cb73 100644 --- a/cpp/include/raft/core/detail/nvtx.hpp +++ b/cpp/include/raft/core/detail/nvtx.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index f72ae36d64..1b9992212e 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -259,6 +259,36 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col return device_matrix_view{ptr, extents}; } +/** + * @brief Create a 2-dim mdspan instance for device pointer with a strided layout + * that is restricted to stride 1 in the trailing dimension. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + * @param[in] stride leading dimension / stride of data + */ +template +auto make_device_strided_matrix_view(ElementType* ptr, + IndexType n_rows, + IndexType n_cols, + IndexType stride) +{ + constexpr auto is_row_major = std::is_same_v; + IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; + IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); + + assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows); + matrix_extent extents{n_rows, n_cols}; + + auto layout = make_strided_layout(extents, std::array{stride0, stride1}); + return device_matrix_view{ptr, layout}; +} + /** * @brief Create a 1-dim mdspan instance for device pointer. 
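As a quick illustration of the new helper above (not part of the diff: the wrapper function and sizes below are made up, while raft::make_device_strided_matrix_view and raft::layout_c_contiguous are taken from the hunk), this is how a row-major device buffer whose rows are padded to `ld` elements could be wrapped:

#include <raft/core/device_mdspan.hpp>

// Wrap a padded row-major [n_rows x n_cols] device buffer; element (i, j) lives at ptr[i * ld + j].
void wrap_padded_rows(float* d_ptr, int n_rows, int n_cols, int ld)
{
  // With a row-major layout request, stride0 becomes ld (or n_cols when ld <= 0),
  // stride1 stays 1, and the assert above checks ld >= n_cols.
  auto view = raft::make_device_strided_matrix_view<float, int, raft::layout_c_contiguous>(
    d_ptr, n_rows, n_cols, ld);
  // view.stride(0) == ld, view.stride(1) == 1
}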
* @tparam ElementType the data type of the vector elements diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index aaf3052892..a68b904470 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -16,13 +16,26 @@ #pragma once +#include +#include #include +#include +//#include +#include +#include #include #include namespace raft::distance::kernels::detail { +template +using dense_input_matrix_view_t = raft::device_matrix_view; +template +using dense_output_matrix_view_t = raft::device_matrix_view; +template +using csr_input_matrix_view_t = raft::device_csr_matrix_view; + /** * Base class for general Gram matrices * A Gram matrix is the Hermitian matrix of inner probucts G_ik = @@ -37,14 +50,135 @@ namespace raft::distance::kernels::detail { */ template class GramMatrixBase { + protected: cublasHandle_t cublas_handle; + bool legacy_interface; public: - GramMatrixBase(cublasHandle_t cublas_handle) : cublas_handle(cublas_handle){}; + GramMatrixBase() : legacy_interface(false){}; + [[deprecated]] GramMatrixBase(cublasHandle_t cublas_handle) + : cublas_handle(cublas_handle), legacy_interface(true){}; virtual ~GramMatrixBase(){}; /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. 
+ */ + void operator()(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + // unfortunately, 'evaluate' cannot be templatized as it needs to be virtual + + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + + /** Evaluate the Gram matrix for two vector sets using simple dot product. * * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 @@ -55,29 +189,26 @@ class GramMatrixBase { * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) */ - virtual void operator()(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1 = 0, - int ld2 = 0, - int ld_out = 0) + [[deprecated]] virtual void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { - if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; } - if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } - if (ld_out <= 0) { ld_out = is_row_major ? 
n2 : n1; } - evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); } - /** Evaluate the Gram matrix for two vector sets using simple dot product. + /** Convenience function to evaluate the Gram matrix for two vector sets. * * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 @@ -88,30 +219,30 @@ class GramMatrixBase { * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out */ - virtual void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void operator()(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1 = 0, + int ld2 = 0, + int ld_out = 0) { - linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + ASSERT(legacy_interface, "Legacy interface can only be used with legacy ctor."); + if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; } + if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } + if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } + evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); } - // private: - // The following methods should be private, they are kept public to avoid: - // "error: The enclosing parent function ("distance") for an extended - // __device__ lambda cannot have private or protected access within its class" - + protected: /** Calculates the Gram matrix using simple dot product between vector sets. * * out = x1 * x2 @@ -131,17 +262,17 @@ class GramMatrixBase { * @param ld2 leading dimension of x2 * @param ld_out leading dimension of out */ - void linear(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void linear(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { math_t alpha = 1.0; math_t beta = 0.0; @@ -182,37 +313,198 @@ class GramMatrixBase { } } - /** Calculates the Gram matrix using Euclidean distance. + protected: + bool get_is_row_major(dense_output_matrix_view_t matrix) + { + return (matrix.stride(1) == 1); + } + + bool get_is_row_major(dense_input_matrix_view_t matrix) + { + return (matrix.stride(1) == 1); + } + + bool get_is_col_major(dense_output_matrix_view_t matrix) + { + return (matrix.stride(0) == 1); + } + + bool get_is_col_major(dense_input_matrix_view_t matrix) + { + return (matrix.stride(0) == 1); + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 * * Can be used as a building block for more complex kernel functions. 
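The raw-pointer path above is kept only for backward compatibility; a minimal sketch (sizes illustrative, not from the diff) of how an existing caller keeps using it. Note that it now requires the deprecated cublas-handle constructor, because the default constructor sets legacy_interface to false and the legacy operator() asserts on it.

#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/kernels/gram_matrix.cuh>
#include <rmm/device_uvector.hpp>

void legacy_linear_gram(raft::device_resources const& handle)
{
  int n1 = 4, n2 = 3, n_cols = 8;
  cudaStream_t stream = handle.get_stream();
  rmm::device_uvector<float> x1(n1 * n_cols, stream);
  rmm::device_uvector<float> x2(n2 * n_cols, stream);
  rmm::device_uvector<float> out(n1 * n2, stream);

  // Deprecated constructor: required for the raw-pointer interface.
  raft::distance::kernels::detail::GramMatrixBase<float> gram(handle.get_cublas_handle());
  // ld1/ld2/ld_out default to the tight leading dimensions when left at 0.
  gram(x1.data(), n1, n_cols, x2.data(), n2, out.data(), /*is_row_major=*/true, stream);
}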
* - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] */ - virtual void distance(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - raft::distance::distance( - raft::device_resources(stream), x1, x2, out, n1, n2, n_cols, is_row_major); + void linear(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out); + bool is_col_major = get_is_col_major(x1) && get_is_col_major(x2) && get_is_col_major(out); + ASSERT(is_row_major || is_col_major, + "GramMatrix leading dimensions for x1, x2 and out do not match"); + + // check dimensions + int n1 = out.extent(0); + int n2 = out.extent(1); + int n_cols = x1.extent(1); + ASSERT(x1.extent(0) == n1, "GramMatrix input matrix dimensions for x1 and out do not match"); + ASSERT(x2.extent(0) == n2, "GramMatrix input matrix dimensions for x2 and out do not match"); + ASSERT(x2.extent(1) == n_cols, "GramMatrix input matrix dimensions for x1 and x2 do not match"); + + // extract major stride + int ld1 = is_row_major ? x1.stride(0) : x1.stride(1); + int ld2 = is_row_major ? x2.stride(0) : x2.stride(1); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + + math_t alpha = 1.0; + math_t beta = 0.0; + if (is_row_major) { + // #TODO: Use mdspan-based API when stride-capable + // https://github.com/rapidsai/raft/issues/875 + raft::linalg::gemm(handle, + true, + false, + n2, + n1, + n_cols, + &alpha, + x2.data_handle(), + ld2, + x1.data_handle(), + ld1, + &beta, + out.data_handle(), + ld_out, + handle.get_stream()); + } else { + // #TODO: Use mdspan-based API when stride-capable + // https://github.com/rapidsai/raft/issues/875 + raft::linalg::gemm(handle, + false, + true, + n1, + n2, + n_cols, + &alpha, + x1.data_handle(), + ld1, + x2.data_handle(), + ld2, + &beta, + out.data_handle(), + ld_out, + handle.get_stream()); + } + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. 
+ * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(x2) && get_is_row_major(out); + bool is_col_major = get_is_col_major(x2) && get_is_col_major(out); + ASSERT(is_row_major || is_col_major, + "GramMatrix leading dimensions for x2 and out do not match"); + + // check dimensions + auto x1_structure = x1.structure_view(); + ASSERT(x1_structure.get_n_rows() == out.extent(0), + "GramMatrix input matrix dimensions for x1 and out do not match"); + ASSERT(x2.extent(0) == out.extent(1), + "GramMatrix input matrix dimensions for x2 and out do not match"); + ASSERT(x2.extent(1) == x1_structure.get_n_cols(), + "GramMatrix input matrix dimensions for x1 and x2 do not match"); + + math_t alpha = 1.0; + math_t beta = 0.0; + + raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out); + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + int minor_out = is_row_major ? 
out.extent(1) : out.extent(0); + ASSERT(ld_out == minor_out, "Sparse linear Kernel distance does not support ld_out parameter"); + + auto x1_structure = x1.structure_view(); + auto x2_structure = x2.structure_view(); + raft::sparse::distance::distances_config_t dist_config(handle); + + // switch a,b based on is_row_major + if (!is_row_major) { + dist_config.a_nrows = x2_structure.get_n_rows(); + dist_config.a_ncols = x2_structure.get_n_cols(); + dist_config.a_nnz = x2_structure.get_nnz(); + dist_config.a_indptr = const_cast(x2_structure.get_indptr().data()); + dist_config.a_indices = const_cast(x2_structure.get_indices().data()); + dist_config.a_data = const_cast(x2.get_elements().data()); + dist_config.b_nrows = x1_structure.get_n_rows(); + dist_config.b_ncols = x1_structure.get_n_cols(); + dist_config.b_nnz = x1_structure.get_nnz(); + dist_config.b_indptr = const_cast(x1_structure.get_indptr().data()); + dist_config.b_indices = const_cast(x1_structure.get_indices().data()); + dist_config.b_data = const_cast(x1.get_elements().data()); + } else { + dist_config.a_nrows = x1_structure.get_n_rows(); + dist_config.a_ncols = x1_structure.get_n_cols(); + dist_config.a_nnz = x1_structure.get_nnz(); + dist_config.a_indptr = const_cast(x1_structure.get_indptr().data()); + dist_config.a_indices = const_cast(x1_structure.get_indices().data()); + dist_config.a_data = const_cast(x1.get_elements().data()); + dist_config.b_nrows = x2_structure.get_n_rows(); + dist_config.b_ncols = x2_structure.get_n_cols(); + dist_config.b_nnz = x2_structure.get_nnz(); + dist_config.b_indptr = const_cast(x2_structure.get_indptr().data()); + dist_config.b_indices = const_cast(x2_structure.get_indices().data()); + dist_config.b_data = const_cast(x2.get_elements().data()); + } + + raft::sparse::distance::pairwiseDistance( + out.data_handle(), dist_config, raft::distance::DistanceType::InnerProduct, 0.0); } }; + }; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh index 1aa6809bcd..bb3ff1c2f5 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
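To show the new handle-based GramMatrixBase entry points above in context, a minimal sketch of a dense linear Gram evaluation. The sizes are illustrative, and the exact template arguments of the dense_*_matrix_view_t aliases are assumed to be const-qualified layout_stride views, which is what the strided-view factory added earlier in this diff produces.

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/kernels/gram_matrix.cuh>
#include <rmm/device_uvector.hpp>

void dense_linear_gram(raft::device_resources const& handle)
{
  int n1 = 4, n2 = 3, n_cols = 8;
  rmm::device_uvector<float> x1(n1 * n_cols, handle.get_stream());
  rmm::device_uvector<float> x2(n2 * n_cols, handle.get_stream());
  rmm::device_uvector<float> out(n1 * n2, handle.get_stream());

  auto x1_view = raft::make_device_strided_matrix_view<const float, int, raft::layout_c_contiguous>(
    x1.data(), n1, n_cols, 0 /* 0 -> tight stride */);
  auto x2_view = raft::make_device_strided_matrix_view<const float, int, raft::layout_c_contiguous>(
    x2.data(), n2, n_cols, 0);
  auto out_view = raft::make_device_strided_matrix_view<float, int, raft::layout_c_contiguous>(
    out.data(), n1, n2, 0);

  raft::distance::kernels::detail::GramMatrixBase<float> gram;  // default ctor: new interface
  gram(handle, x1_view, x2_view, out_view);                     // linear kernel, norms unused
}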
@@ -26,19 +26,35 @@ namespace raft::distance::kernels::detail { template class KernelFactory { public: - static GramMatrixBase* create(KernelParams params, cublasHandle_t cublas_handle) + static GramMatrixBase* create(KernelParams params) { GramMatrixBase* res; // KernelParams is not templated, we convert the parameters to math_t here: math_t coef0 = params.coef0; math_t gamma = params.gamma; switch (params.kernel) { - case LINEAR: res = new GramMatrixBase(cublas_handle); break; + case LINEAR: res = new GramMatrixBase(); break; + case POLYNOMIAL: res = new PolynomialKernel(params.degree, gamma, coef0); break; + case TANH: res = new TanhKernel(gamma, coef0); break; + case RBF: res = new RBFKernel(gamma); break; + default: throw raft::exception("Kernel not implemented"); + } + return res; + } + + [[deprecated]] static GramMatrixBase* create(KernelParams params, cublasHandle_t handle) + { + GramMatrixBase* res; + // KernelParams is not templated, we convert the parameters to math_t here: + math_t coef0 = params.coef0; + math_t gamma = params.gamma; + switch (params.kernel) { + case LINEAR: res = new GramMatrixBase(handle); break; case POLYNOMIAL: - res = new PolynomialKernel(params.degree, gamma, coef0, cublas_handle); + res = new PolynomialKernel(params.degree, gamma, coef0, handle); break; - case TANH: res = new TanhKernel(gamma, coef0, cublas_handle); break; - case RBF: res = new RBFKernel(gamma); break; + case TANH: res = new TanhKernel(gamma, coef0, handle); break; + case RBF: res = new RBFKernel(gamma, handle); break; default: throw raft::exception("Kernel not implemented"); } return res; diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index d1465efdb0..4b000add21 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -21,6 +21,7 @@ #include #include +#include namespace raft::distance::kernels::detail { @@ -100,6 +101,38 @@ __global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t ga } } +/** Epiloge function for rbf kernel using expansion. + * + * Calculates output_ij = exp(-gain * (norm_x_i + norm_y_j - 2*input_ij)); + * + * Intended usage + * - input is the product of two matrices X and Y input_ij = sum_k X_ik * Y_jk + * - norm_x_i = l2_norm(x_i), where x_i is the i-th row of matrix X + * - norm_y_j = l2_norm(y_j), where y_j is the j-th row of matrix Y + * + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of columns + * @param norm_x l2-norm of X's rows + * @param norm_y l2-norm of Y's rows + * @param gain + */ +template +__global__ void rbf_kernel_expanded( + math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) { + math_t norm_y_val = norm_y[tidy]; + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = + exp(-1.0 * gain * (norm_x[tidx] + norm_y_val - inout[tidx + tidy * ld] * 2)); + } + } +} + /** * Create a kernel matrix using polynomial kernel function. 
*/ @@ -138,11 +171,69 @@ class PolynomialKernel : public GramMatrixBase { * @param exponent * @param gain * @param offset - * @param cublas_handle */ - PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t cublas_handle) - : GramMatrixBase(cublas_handle), exponent(exponent), gain(gain), offset(offset) + PolynomialKernel(exp_t exponent, math_t gain, math_t offset) + : GramMatrixBase(), exponent(exponent), gain(gain), offset(offset) + { + } + + [[deprecated]] PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t handle) + : GramMatrixBase(handle), exponent(exponent), gain(gain), offset(offset) + { + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); } /** Evaluate kernel matrix using polynomial kernel. @@ -150,32 +241,57 @@ class PolynomialKernel : public GramMatrixBase { * output[i,k] = (gain* + offset)^exponent, * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? 
out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate the Gram matrix using the legacy interface. * * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of features in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*cols] + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] * @param [in] n2 number vectors in x2 * @param [out] out device buffer to store the Gram matrix, size [n1*n2] * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) */ - void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); GramMatrixBase::linear( x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); applyKernel(out, ld_out, n1, n2, is_row_major, stream); @@ -216,10 +332,11 @@ class TanhKernel : public GramMatrixBase { * @tparam math_t floating point type * @param gain * @param offset - * @param cublas_handle */ - TanhKernel(math_t gain, math_t offset, cublasHandle_t cublas_handle) - : GramMatrixBase(cublas_handle), gain(gain), offset(offset) + TanhKernel(math_t gain, math_t offset) : GramMatrixBase(), gain(gain), offset(offset) {} + + [[deprecated]] TanhKernel(math_t gain, math_t offset, cublasHandle_t handle) + : GramMatrixBase(handle), gain(gain), offset(offset) { } @@ -229,12 +346,87 @@ class TanhKernel : public GramMatrixBase { * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector * in the x2 set, and < , > denotes dot product. * - * @param [in] x1 device array of vectors, - * size [n1*n_cols] + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. 
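Putting the handle-free constructors above together with the factory change in kernel_factory.cuh, a sketch of creating and applying a polynomial kernel through the new interface; the kernel parameters and the layout_stride view types are illustrative assumptions.

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/kernels/kernel_factory.cuh>
#include <memory>

void poly_gram(raft::device_resources const& handle,
               raft::device_matrix_view<const float, int, raft::layout_stride> x1,
               raft::device_matrix_view<const float, int, raft::layout_stride> x2,
               raft::device_matrix_view<float, int, raft::layout_stride> out)
{
  using namespace raft::distance::kernels;
  KernelParams params{KernelType::POLYNOMIAL, /*degree=*/3, /*gamma=*/0.5, /*coef0=*/1.0};
  // No cublas handle needed any more; the returned object dispatches to the
  // view-based PolynomialKernel::evaluate above.
  auto kernel = std::unique_ptr<detail::GramMatrixBase<float>>(
    detail::KernelFactory<float>::create(params));
  (*kernel)(handle, x1, x2, out);
}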
+ * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate the Gram matrix using the legacy interface. 
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of features in x1 and x2 - * @param [in] x2 device array of vectors, - * size [n2*n_cols] + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] * @param [in] n2 number vectors in x2 * @param [out] out device buffer to store the Gram matrix, size [n1*n2] * @param [in] is_row_major whether the input and output matrices are in row @@ -244,18 +436,20 @@ class TanhKernel : public GramMatrixBase { * @param ld2 leading dimension of x2 (usually it is n2) * @param ld_out leading dimension of out (usually it is n1) */ - void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); GramMatrixBase::linear( x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); applyKernel(out, ld_out, n1, n2, is_row_major, stream); @@ -269,21 +463,23 @@ template class RBFKernel : public GramMatrixBase { math_t gain; - void applyKernel( - math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + void applyKernel(math_t* inout, + int ld, + int rows, + int cols, + math_t* norm_x1, + math_t* norm_x2, + bool is_row_major, + cudaStream_t stream) { - const int n_minor = is_row_major ? cols : rows; - if (ld == n_minor) { - rbf_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( - inout, rows * cols, gain); - } else { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - rbf_kernel<<>>(inout, ld, n1, n2, gain); - } + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1; + math_t* norm_n2 = is_row_major ? norm_x1 : norm_x2; + rbf_kernel_expanded<<>>(inout, ld, n1, n2, norm_n1, norm_n2, gain); } public: @@ -295,61 +491,230 @@ class RBFKernel : public GramMatrixBase { * @tparam math_t floating point type * @param gain */ - RBFKernel(math_t gain) : GramMatrixBase(NULL), gain(gain) {} + RBFKernel(math_t gain) : GramMatrixBase(), gain(gain) {} + + [[deprecated]] RBFKernel(math_t gain, cublasHandle_t handle) + : GramMatrixBase(handle), gain(gain) + { + } + + void matrixRowNormL2(raft::device_resources const& handle, + dense_input_matrix_view_t matrix, + math_t* target) + { + bool is_row_major = GramMatrixBase::get_is_row_major(matrix); + int minor = is_row_major ? matrix.extent(1) : matrix.extent(0); + int ld = is_row_major ? 
matrix.stride(0) : matrix.stride(1); + ASSERT(ld == minor, "RBF Kernel lazy rowNorm compute does not support ld parameter"); + raft::linalg::rowNorm(target, + matrix.data_handle(), + matrix.extent(1), + matrix.extent(0), + raft::linalg::NormType::L2Norm, + is_row_major, + handle.get_stream()); + } + + void matrixRowNormL2(raft::device_resources const& handle, + csr_input_matrix_view_t matrix, + math_t* target) + { + auto matrix_structure = matrix.structure_view(); + raft::sparse::linalg::rowNormCsr(handle, + matrix_structure.get_indptr().data(), + matrix.get_elements().data(), + matrix_structure.get_nnz(), + matrix_structure.get_n_rows(), + target, + raft::linalg::NormType::L2Norm); + } /** Evaluate kernel matrix using RBF kernel. * * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector * in the x2 set, and | | euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = handle.get_stream(); + + // lazy compute norms if not given + rmm::device_uvector tmp_norm_x1(0, stream); + rmm::device_uvector tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.extent(0), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.extent(0), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + handle.get_stream()); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and | | euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. 
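The dense RBF evaluate above computes the row norms lazily when none are passed; a caller that evaluates many Gram blocks against the same data can precompute them once and hand them in. A sketch under the assumption that the inputs are tightly packed row-major views and that the norms are the squared L2 row norms expected by rbf_kernel_expanded:

#include <raft/core/device_mdspan.hpp>
#include <raft/distance/detail/kernels/kernel_matrices.cuh>
#include <raft/linalg/norm.cuh>
#include <rmm/device_uvector.hpp>

void rbf_gram_with_cached_norms(raft::device_resources const& handle,
                                raft::device_matrix_view<const float, int, raft::layout_stride> x1,
                                raft::device_matrix_view<const float, int, raft::layout_stride> x2,
                                raft::device_matrix_view<float, int, raft::layout_stride> out)
{
  cudaStream_t stream = handle.get_stream();
  rmm::device_uvector<float> norm_x1(x1.extent(0), stream);
  rmm::device_uvector<float> norm_x2(x2.extent(0), stream);
  // Squared L2 row norms (no sqrt finalizer), matching what the lazy path computes.
  raft::linalg::rowNorm(norm_x1.data(), x1.data_handle(), x1.extent(1), x1.extent(0),
                        raft::linalg::NormType::L2Norm, /*rowMajor=*/true, stream);
  raft::linalg::rowNorm(norm_x2.data(), x2.data_handle(), x2.extent(1), x2.extent(0),
                        raft::linalg::NormType::L2Norm, /*rowMajor=*/true, stream);

  raft::distance::kernels::detail::RBFKernel<float> rbf(/*gain=*/0.5f);
  rbf.evaluate(handle, x1, x2, out, norm_x1.data(), norm_x2.data());
}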
+ */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = handle.get_stream(); + + // lazy compute norms if not given + rmm::device_uvector tmp_norm_x1(0, stream); + rmm::device_uvector tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.extent(0), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + handle.get_stream()); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and | | euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = handle.get_stream(); + + // lazy compute norms if not given + rmm::device_uvector tmp_norm_x1(0, stream); + rmm::device_uvector tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.structure_view().get_n_rows(), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + handle.get_stream()); + } + + /** Evaluate the Gram matrix using the legacy interface. 
* * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of features in x1 and x2 + * @param [in] n_cols number of columns (features) in x1 and x2 * @param [in] x2 device array of vectors, size [n2*n_cols] * @param [in] n2 number vectors in x2 * @param [out] out device buffer to store the Gram matrix, size [n1*n2] * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1, currently only ld1 == n1 is supported - * @param ld2 leading dimension of x2, currently only ld2 == n2 is supported - * @param ld_out leading dimension of out, only ld_out == n1 is supported + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) */ - void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); int minor1 = is_row_major ? n_cols : n1; int minor2 = is_row_major ? n_cols : n2; int minor_out = is_row_major ? n2 : n1; ASSERT(ld1 == minor1, "RBF Kernel distance does not support ld1 parameter"); ASSERT(ld2 == minor2, "RBF Kernel distance does not support ld2 parameter"); ASSERT(ld_out == minor_out, "RBF Kernel distance does not support ld_out parameter"); - distance(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - } - /** Customize distance function withe RBF epilogue */ - void distance(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { math_t gain = this->gain; using index_t = int64_t; diff --git a/cpp/include/raft/sparse/linalg/detail/norm.cuh b/cpp/include/raft/sparse/linalg/detail/norm.cuh index c2a8aa4246..56ca2ebfa7 100644 --- a/cpp/include/raft/sparse/linalg/detail/norm.cuh +++ b/cpp/include/raft/sparse/linalg/detail/norm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
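For reference, the identity the expanded RBF path above relies on (linear Gram followed by the rbf_kernel_expanded epilogue), with G the dot-product Gram matrix and norm_x, norm_y the squared L2 row norms:

\[
  \lVert x_i - y_j \rVert^2 = \lVert x_i \rVert^2 + \lVert y_j \rVert^2 - 2\,\langle x_i, y_j \rangle
  \quad\Longrightarrow\quad
  K_{ij} = \exp\!\bigl(-\mathrm{gain}\,(\mathrm{norm\_x}_i + \mathrm{norm\_y}_j - 2\,G_{ij})\bigr)
\]

This is why the RBF evaluate overloads only need the linear Gram matrix plus the two norm vectors, instead of a dedicated pairwise-distance call.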
@@ -17,10 +17,15 @@ #pragma once #include +#include +#include +#include #include #include #include +#include + #include #include @@ -170,6 +175,62 @@ void csr_row_normalize_max(const int* ia, // csr row ind array (sorted by row) RAFT_CUDA_TRY(cudaGetLastError()); } +template +void csr_row_op_wrapper(const IdxType* ia, + const Type* data, + IdxType nnz, + IdxType N, + Type init, + Type* norm, + cudaStream_t stream, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + op::csr_row_op( + ia, + N, + nnz, + [data, init, norm, main_op, reduce_op, final_op] __device__( + IdxType row, IdxType start_idx, IdxType stop_idx) { + norm[row] = init; + for (IdxType i = start_idx; i < stop_idx; i++) + norm[row] = final_op(reduce_op(norm[row], main_op(data[i]))); + }, + stream); +} + +template +void rowNormCsrCaller(const IdxType* ia, + const Type* data, + IdxType nnz, + IdxType N, + Type* norm, + raft::linalg::NormType type, + Lambda fin_op, + cudaStream_t stream) +{ + switch (type) { + case raft::linalg::NormType::L1Norm: + csr_row_op_wrapper( + ia, data, nnz, N, (Type)0, norm, stream, raft::abs_op(), raft::add_op(), fin_op); + break; + case raft::linalg::NormType::L2Norm: + csr_row_op_wrapper( + ia, data, nnz, N, (Type)0, norm, stream, raft::sq_op(), raft::add_op(), fin_op); + break; + case raft::linalg::NormType::LinfNorm: + csr_row_op_wrapper( + ia, data, nnz, N, (Type)0, norm, stream, raft::abs_op(), raft::max_op(), fin_op); + break; + default: THROW("Unsupported norm type: %d", type); + }; +} + }; // end NAMESPACE detail }; // end NAMESPACE linalg }; // end NAMESPACE sparse diff --git a/cpp/include/raft/sparse/linalg/detail/spmm.hpp b/cpp/include/raft/sparse/linalg/detail/spmm.hpp new file mode 100644 index 0000000000..b61b561a12 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/detail/spmm.hpp @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
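A small sketch of driving the new CSR row-norm path (the public raft::sparse::linalg::rowNormCsr wrapper appears further down in this diff; the buffer sizes are illustrative and the buffers are assumed to already hold a valid CSR matrix):

#include <raft/core/device_resources.hpp>
#include <raft/sparse/linalg/norm.cuh>
#include <rmm/device_uvector.hpp>

void csr_row_norms(raft::device_resources const& handle)
{
  int n_rows = 4, nnz = 9;
  rmm::device_uvector<int> indptr(n_rows + 1, handle.get_stream());
  rmm::device_uvector<float> values(nnz, handle.get_stream());
  rmm::device_uvector<float> norms(n_rows, handle.get_stream());
  // ... indptr and values are assumed to be filled with a valid CSR structure ...

  raft::sparse::linalg::rowNormCsr(handle,
                                   indptr.data(),
                                   values.data(),
                                   nnz,
                                   n_rows,
                                   norms.data(),
                                   raft::linalg::NormType::L2Norm);  // squared L2, no final sqrt
}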
+ */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { +namespace sparse { +namespace linalg { +namespace detail { + +/** + * @brief determine common data layout for both dense matrices + * @tparam ValueType Data type of Y,Z (float/double) + * @tparam IndexType Type of Y,Z + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] x input raft::device_matrix_view + * @param[in] y input raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +bool is_row_major(raft::device_matrix_view& y, + raft::device_matrix_view& z) +{ + bool is_row_major = z.stride(1) == 1 && y.stride(1) == 1; + bool is_col_major = z.stride(0) == 1 && y.stride(0) == 1; + ASSERT(is_row_major || is_col_major, "Both matrices need to be either row or col major"); + return is_row_major; +} + +/** + * @brief create a cuSparse dense descriptor + * @tparam ValueType Data type of dense_view (float/double) + * @tparam IndexType Type of dense_view + * @tparam LayoutPolicy layout of dense_view + * @param[in] dense_view input raft::device_matrix_view + * @param[in] is_row_major data layout of raft::device_matrix_view + * @returns dense matrix descriptor to be used by cuSparse API + */ +template +cusparseDnMatDescr_t create_descriptor( + raft::device_matrix_view& dense_view, const bool is_row_major) +{ + auto order = is_row_major ? CUSPARSE_ORDER_ROW : CUSPARSE_ORDER_COL; + IndexType ld = is_row_major ? dense_view.stride(0) : dense_view.stride(1); + cusparseDnMatDescr_t descr; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatednmat( + &descr, + dense_view.extent(0), + dense_view.extent(1), + ld, + const_cast*>(dense_view.data_handle()), + order)); + return descr; +} + +/** + * @brief create a cuSparse sparse descriptor + * @tparam ValueType Data type of sparse_view (float/double) + * @tparam NZType Type of sparse_view + * @param[in] sparse_view input raft::device_csr_matrix_view of size M rows x K columns + * @returns sparse matrix descriptor to be used by cuSparse API + */ +template +cusparseSpMatDescr_t create_descriptor( + raft::device_csr_matrix_view& sparse_view) +{ + cusparseSpMatDescr_t descr; + auto csr_structure = sparse_view.structure_view(); + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsecreatecsr( + &descr, + csr_structure.get_n_rows(), + csr_structure.get_n_cols(), + csr_structure.get_nnz(), + const_cast(csr_structure.get_indptr().data()), + const_cast(csr_structure.get_indices().data()), + const_cast*>(sparse_view.get_elements().data()))); + return descr; +} + +/** + * @brief SPMM function designed for handling all CSR * DENSE + * combinations of operand layouts for cuSparse. + * It computes the following equation: Z = alpha . X * Y + beta . 
Z + * where X is a CSR device matrix view and Y,Z are device matrix views + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of Y and Z + * @tparam NZType Type of X + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] handle raft handle + * @param[in] trans_x transpose operation for X + * @param[in] trans_y transpose operation for Y + * @param[in] is_row_major data layout of Y,Z + * @param[in] alpha scalar + * @param[in] descr_x input sparse descriptor + * @param[in] descr_y input dense descriptor + * @param[in] beta scalar + * @param[out] descr_z output dense descriptor + */ +template +void spmm(raft::device_resources const& handle, + const bool trans_x, + const bool trans_y, + const bool is_row_major, + const ValueType* alpha, + cusparseSpMatDescr_t& descr_x, + cusparseDnMatDescr_t& descr_y, + const ValueType* beta, + cusparseDnMatDescr_t& descr_z) +{ + auto opX = trans_x ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + auto opY = trans_y ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE; + auto alg = is_row_major ? CUSPARSE_SPMM_CSR_ALG2 : CUSPARSE_SPMM_CSR_ALG1; + size_t bufferSize; + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm_bufferSize(handle.get_cusparse_handle(), + opX, + opY, + alpha, + descr_x, + descr_y, + beta, + descr_z, + alg, + &bufferSize, + handle.get_stream())); + + raft::interruptible::synchronize(handle.get_stream()); + + rmm::device_uvector tmp(bufferSize, handle.get_stream()); + + RAFT_CUSPARSE_TRY(raft::sparse::detail::cusparsespmm(handle.get_cusparse_handle(), + opX, + opY, + alpha, + descr_x, + descr_y, + beta, + descr_z, + alg, + tmp.data(), + handle.get_stream())); +} + +} // end namespace detail +} // end namespace linalg +} // end namespace sparse +} // end namespace raft diff --git a/cpp/include/raft/sparse/linalg/norm.cuh b/cpp/include/raft/sparse/linalg/norm.cuh index e13fd22843..2bd48c6dc6 100644 --- a/cpp/include/raft/sparse/linalg/norm.cuh +++ b/cpp/include/raft/sparse/linalg/norm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ #pragma once +#include #include namespace raft { @@ -66,6 +67,38 @@ void csr_row_normalize_max(const int* ia, // csr row ind array (sorted by row) detail::csr_row_normalize_max(ia, vals, nnz, m, result, stream); } +/** + * @brief Compute row-wise norm of the input matrix and perform fin_op lambda + * + * Row-wise norm is useful while computing pairwise distance matrix, for + * example. + * This is used in many clustering algos like knn, kmeans, dbscan, etc... 
+ * + * @tparam Type the data type + * @tparam Lambda device final lambda + * @tparam IdxType Integer type used to for addressing + * @param handle raft handle + * @param ia the input matrix row index array + * @param data the input matrix nnz data + * @param nnz number of elements in data + * @param N number of rows + * @param norm the output vector of row-wise norm, size [N] + * @param type the type of norm to be applied + * @param fin_op the final lambda op + */ +template +void rowNormCsr(raft::device_resources const& handle, + const IdxType* ia, + const Type* data, + const IdxType nnz, + const IdxType N, + Type* norm, + raft::linalg::NormType type, + Lambda fin_op = raft::identity_op()) +{ + detail::rowNormCsrCaller(ia, data, nnz, N, norm, type, fin_op, handle.get_stream()); +} + }; // end NAMESPACE linalg }; // end NAMESPACE sparse }; // end NAMESPACE raft diff --git a/cpp/include/raft/sparse/linalg/spmm.cuh b/cpp/include/raft/sparse/linalg/spmm.cuh new file mode 100644 index 0000000000..73170cfc70 --- /dev/null +++ b/cpp/include/raft/sparse/linalg/spmm.cuh @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __SPMM_H +#define __SPMM_H + +#pragma once + +#include "detail/spmm.hpp" + +namespace raft { +namespace sparse { +namespace linalg { + +/** + * @brief SPMM function designed for handling all CSR * DENSE + * combinations of operand layouts for cuSparse. + * It computes the following equation: Z = alpha . X * Y + beta . 
Z + * where X is a CSR device matrix view and Y,Z are device matrix views + * @tparam ValueType Data type of input/output matrices (float/double) + * @tparam IndexType Type of Y and Z + * @tparam NZType Type of X + * @tparam LayoutPolicyY layout of Y + * @tparam LayoutPolicyZ layout of Z + * @param[in] handle raft handle + * @param[in] trans_x transpose operation for X + * @param[in] trans_y transpose operation for Y + * @param[in] alpha scalar + * @param[in] x input raft::device_csr_matrix_view + * @param[in] y input raft::device_matrix_view + * @param[in] beta scalar + * @param[out] z output raft::device_matrix_view + */ +template +void spmm(raft::device_resources const& handle, + const bool trans_x, + const bool trans_y, + const ValueType* alpha, + raft::device_csr_matrix_view x, + raft::device_matrix_view y, + const ValueType* beta, + raft::device_matrix_view z) +{ + bool is_row_major = detail::is_row_major(y, z); + + auto descr_x = detail::create_descriptor(x); + auto descr_y = detail::create_descriptor(y, is_row_major); + auto descr_z = detail::create_descriptor(z, is_row_major); + + detail::spmm(handle, trans_x, trans_y, is_row_major, alpha, descr_x, descr_y, beta, descr_z); + + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroySpMat(descr_x)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_y)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroyDnMat(descr_z)); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // end namespace linalg +} // end namespace sparse +} // end namespace raft + +#endif diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 22e8a9d73c..c8d4f91ec0 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -236,6 +236,7 @@ if(BUILD_TESTS) test/sparse/degree.cu test/sparse/filter.cu test/sparse/norm.cu + test/sparse/normalize.cu test/sparse/reduce.cu test/sparse/row_op.cu test/sparse/sort.cu @@ -244,7 +245,14 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME SPARSE_DIST_TEST PATH test/sparse/dist_coo_spmv.cu test/sparse/distance.cu OPTIONAL LIB + NAME + SPARSE_DIST_TEST + PATH + test/sparse/dist_coo_spmv.cu + test/sparse/distance.cu + test/sparse/gram.cu + OPTIONAL + LIB ) ConfigureTest( diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index f99d02dc7f..47da201465 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -19,6 +19,7 @@ #endif #include "../test_utils.cuh" +#include "gram_base.cuh" #include #include #include @@ -31,12 +32,6 @@ namespace raft::distance::kernels { -// Get the offset of element [i,k]. -HDI int get_offset(int i, int k, int ld, bool is_row_major) -{ - return is_row_major ? i * ld + k : i + k * ld; -} - struct GramMatrixInputs { int n1; // feature vectors in matrix 1 int n2; // featuer vectors in matrix 2 @@ -110,62 +105,46 @@ class GramMatrixTest : public ::testing::TestWithParam { ~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); } - // Calculate the Gram matrix on the host. 
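A minimal sketch of the public wrapper introduced above in raft/sparse/linalg/spmm.cuh, computing Z = alpha * X * Y + beta * Z for a small CSR matrix X and row-major dense Y, Z. The sizes are illustrative, the buffers are assumed to already hold valid data, and the const-qualification of the x/y element types is inferred from the Gram code earlier in this diff rather than spelled out in the hunk:

#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/sparse/linalg/spmm.cuh>
#include <rmm/device_uvector.hpp>

void csr_times_dense(raft::device_resources const& handle)
{
  // Z (3x2) = 1.0 * X (3x4, CSR) * Y (4x2) + 0.0 * Z
  int n_rows = 3, n_cols = 4, n_out = 2, nnz = 6;
  rmm::device_uvector<int> indptr(n_rows + 1, handle.get_stream());
  rmm::device_uvector<int> indices(nnz, handle.get_stream());
  rmm::device_uvector<float> values(nnz, handle.get_stream());
  rmm::device_uvector<float> y(n_cols * n_out, handle.get_stream());
  rmm::device_uvector<float> z(n_rows * n_out, handle.get_stream());
  // ... indptr/indices/values and y are assumed to be filled here ...

  auto structure = raft::make_device_compressed_structure_view(
    indptr.data(), indices.data(), n_rows, n_cols, nnz);
  auto x = raft::make_device_csr_matrix_view<const float, int, int, int>(values.data(), structure);
  auto y_view = raft::make_device_matrix_view<const float, int, raft::layout_c_contiguous>(
    y.data(), n_cols, n_out);
  auto z_view = raft::make_device_matrix_view<float, int, raft::layout_c_contiguous>(
    z.data(), n_rows, n_out);

  float alpha = 1.0f, beta = 0.0f;
  raft::sparse::linalg::spmm(handle, false, false, &alpha, x, y_view, &beta, z_view);
}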
- void naiveKernel() - { - std::vector x1_host(x1.size()); - raft::update_host(x1_host.data(), x1.data(), x1.size(), stream); - std::vector x2_host(x2.size()); - raft::update_host(x2_host.data(), x2.data(), x2.size(), stream); - handle.sync_stream(stream); - - for (int i = 0; i < params.n1; i++) { - for (int j = 0; j < params.n2; j++) { - float d = 0; - for (int k = 0; k < params.n_cols; k++) { - if (params.kernel.kernel == KernelType::RBF) { - math_t diff = x1_host[get_offset(i, k, params.ld1, params.is_row_major)] - - x2_host[get_offset(j, k, params.ld2, params.is_row_major)]; - d += diff * diff; - } else { - d += x1_host[get_offset(i, k, params.ld1, params.is_row_major)] * - x2_host[get_offset(j, k, params.ld2, params.is_row_major)]; - } - } - int idx = get_offset(i, j, params.ld_out, params.is_row_major); - math_t v = 0; - switch (params.kernel.kernel) { - case (KernelType::LINEAR): gram_host[idx] = d; break; - case (KernelType::POLYNOMIAL): - v = params.kernel.gamma * d + params.kernel.coef0; - gram_host[idx] = std::pow(v, params.kernel.degree); - break; - case (KernelType::TANH): - gram_host[idx] = std::tanh(params.kernel.gamma * d + params.kernel.coef0); - break; - case (KernelType::RBF): gram_host[idx] = exp(-params.kernel.gamma * d); break; - } - } - } - } - void runTest() { - std::unique_ptr> kernel = std::unique_ptr>( - KernelFactory::create(params.kernel, handle.get_cublas_handle())); - - kernel->evaluate(x1.data(), - params.n1, - params.n_cols, - x2.data(), - params.n2, - gram.data(), - params.is_row_major, - stream, - params.ld1, - params.ld2, - params.ld_out); - naiveKernel(); + std::unique_ptr> kernel = + std::unique_ptr>(KernelFactory::create(params.kernel)); + + auto x1_span = + params.is_row_major + ? raft::make_device_strided_matrix_view( + x1.data(), params.n1, params.n_cols, params.ld1) + : raft::make_device_strided_matrix_view( + x1.data(), params.n1, params.n_cols, params.ld1); + auto x2_span = + params.is_row_major + ? raft::make_device_strided_matrix_view( + x2.data(), params.n2, params.n_cols, params.ld2) + : raft::make_device_strided_matrix_view( + x2.data(), params.n2, params.n_cols, params.ld2); + auto out_span = + params.is_row_major + ? raft::make_device_strided_matrix_view( + gram.data(), params.n1, params.n2, params.ld_out) + : raft::make_device_strided_matrix_view( + gram.data(), params.n1, params.n2, params.ld_out); + + (*kernel)(handle, x1_span, x2_span, out_span); + + naiveGramMatrixKernel(params.n1, + params.n2, + params.n_cols, + x1, + x2, + gram_host.data(), + params.ld1, + params.ld2, + params.ld_out, + params.is_row_major, + params.kernel, + stream, + handle); + ASSERT_TRUE(raft::devArrMatchHost( gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f))); } diff --git a/cpp/test/distance/gram_base.cuh b/cpp/test/distance/gram_base.cuh new file mode 100644 index 0000000000..8c0652bc16 --- /dev/null +++ b/cpp/test/distance/gram_base.cuh @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +namespace kernels { + +// Get the offset of element [i,k]. +HDI int get_offset(int i, int k, int ld, bool is_row_major) +{ + return is_row_major ? i * ld + k : i + k * ld; +} + +// Calculate the Gram matrix on the host. +template +void naiveGramMatrixKernel(int n1, + int n2, + int n_cols, + const rmm::device_uvector& x1, + const rmm::device_uvector& x2, + math_t* gram_host, + int ld1, + int ld2, + int ld_out, + bool is_row_major, + KernelParams kernel, + cudaStream_t stream, + const raft::device_resources& handle) +{ + std::vector x1_host(x1.size()); + raft::update_host(x1_host.data(), x1.data(), x1.size(), stream); + std::vector x2_host(x2.size()); + raft::update_host(x2_host.data(), x2.data(), x2.size(), stream); + handle.sync_stream(stream); + + for (int i = 0; i < n1; i++) { + for (int j = 0; j < n2; j++) { + float d = 0; + for (int k = 0; k < n_cols; k++) { + if (kernel.kernel == KernelType::RBF) { + math_t diff = x1_host[get_offset(i, k, ld1, is_row_major)] - + x2_host[get_offset(j, k, ld2, is_row_major)]; + d += diff * diff; + } else { + d += x1_host[get_offset(i, k, ld1, is_row_major)] * + x2_host[get_offset(j, k, ld2, is_row_major)]; + } + } + int idx = get_offset(i, j, ld_out, is_row_major); + math_t v = 0; + switch (kernel.kernel) { + case (KernelType::LINEAR): gram_host[idx] = d; break; + case (KernelType::POLYNOMIAL): + v = kernel.gamma * d + kernel.coef0; + gram_host[idx] = std::pow(v, kernel.degree); + break; + case (KernelType::TANH): gram_host[idx] = std::tanh(kernel.gamma * d + kernel.coef0); break; + case (KernelType::RBF): gram_host[idx] = exp(-kernel.gamma * d); break; + } + } + } +} + +} // namespace kernels +} // namespace distance +} // namespace raft diff --git a/cpp/test/sparse/gram.cu b/cpp/test/sparse/gram.cu new file mode 100644 index 0000000000..86a2e0cf43 --- /dev/null +++ b/cpp/test/sparse/gram.cu @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined RAFT_DISTANCE_COMPILED +#include +#endif + +#include "../distance/gram_base.cuh" +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::distance::kernels { + +/** + * Structure to describe structure of the input matrices: + * - DENSE: dense, dense + * - MIX: CSR, dense + * - CSR: CSR, CSR + */ +enum SparseType { DENSE, MIX, CSR }; + +struct GramMatrixInputs { + int n1; // feature vectors in matrix 1 + int n2; // featuer vectors in matrix 2 + int n_cols; // number of elements in a feature vector + bool is_row_major; + SparseType sparse_input; + KernelParams kernel; + int ld1; + int ld2; + int ld_out; + // We will generate random input using the dimensions given here. 
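The relocated host reference naiveGramMatrixKernel above reduces each row pair (i, j) to a scalar d (the dot product of the two rows, or their squared Euclidean distance in the RBF case) and then applies the map selected by KernelParams:

  linear:      gram[i, j] = d
  polynomial:  gram[i, j] = (gamma * d + coef0)^degree
  tanh:        gram[i, j] = tanh(gamma * d + coef0)
  rbf:         gram[i, j] = exp(-gamma * d)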
+ // The reference output is calculated by a custom kernel. +}; + +std::ostream& operator<<(std::ostream& os, const GramMatrixInputs& p) +{ + std::vector kernel_names{"linear", "poly", "rbf", "tanh"}; + os << "/" << p.n1 << "x" << p.n2 << "x" << p.n_cols << "/" + << (p.is_row_major ? "RowMajor/" : "ColMajor/") + << (p.sparse_input == SparseType::DENSE + ? "DenseDense/" + : (p.sparse_input == SparseType::MIX ? "CsrDense/" : "CsrCsr/")) + << kernel_names[p.kernel.kernel] << "/ld_" << p.ld1 << "x" << p.ld2 << "x" << p.ld_out; + return os; +} + +/*struct KernelParams { + // Kernel function parameters + KernelType kernel; //!< Type of the kernel function + int degree; //!< Degree of polynomial kernel (ignored by others) + double gamma; //!< multiplier in the + double coef0; //!< additive constant in poly and tanh kernels +};*/ + +// const KernelParams linear_kernel_params{.kernel=KernelType::LINEAR}; + +// {KernelType::POLYNOMIAL, 2, 0.5, 2.4}, {KernelType::TANH, 0, 0.5, 2.4}, {KernelType::RBF, 0, 0.5} +const std::vector inputs = raft::util::itertools::product( + {42}, + {137}, + {2}, + {true, false}, + {SparseType::DENSE, SparseType::MIX, SparseType::CSR}, + {KernelParams{KernelType::LINEAR}, + KernelParams{KernelType::POLYNOMIAL, 2, 0.5, 2.4}, + KernelParams{KernelType::TANH, 0, 0.5, 2.4}, + KernelParams{KernelType::RBF, 0, 0.5}}); + +// (ld_1, ld_2, ld_out) not supported by RBF and CSR +const std::vector inputs_ld = raft::util::itertools::product( + {137}, + {42}, + {2}, + {true, false}, + {SparseType::DENSE, SparseType::MIX}, + {KernelParams{KernelType::LINEAR}, + KernelParams{KernelType::POLYNOMIAL, 2, 0.5, 2.4}, + KernelParams{KernelType::TANH, 0, 0.5, 2.4}}, + {159}, + {73}, + {144}); + +// (ld_1, ld_2) are supported by CSR +const std::vector inputs_ld_csr = + raft::util::itertools::product( + {42}, + {137}, + {2}, + {true, false}, + {SparseType::CSR, SparseType::MIX}, + {KernelParams{KernelType::LINEAR}, + KernelParams{KernelType::POLYNOMIAL, 2, 0.5, 2.4}, + KernelParams{KernelType::TANH, 0, 0.5, 2.4}}, + {64}, + {155}, + {0}); + +template +class GramMatrixTest : public ::testing::TestWithParam { + protected: + GramMatrixTest() + : params(GetParam()), + stream(0), + x1(0, stream), + x2(0, stream), + x1_csr_indptr(0, stream), + x1_csr_indices(0, stream), + x1_csr_data(0, stream), + x2_csr_indptr(0, stream), + x2_csr_indices(0, stream), + x2_csr_data(0, stream), + gram(0, stream), + gram_host(0) + { + RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + + if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; } + if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; } + if (params.ld_out == 0) { params.ld_out = params.is_row_major ? params.n2 : params.n1; } + // Derive the size of the output from the offset of the last element. 
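+ // (x1, x2 and the Gram buffer are all sized this way: get_offset(i, k, ld, row_major) is
+ //  i * ld + k for row-major data and i + k * ld otherwise, so the offset of the last
+ //  element plus one is the minimum number of elements that has to be allocated.)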
+ size_t size = get_offset(params.n1 - 1, params.n_cols - 1, params.ld1, params.is_row_major) + 1; + x1.resize(size, stream); + size = get_offset(params.n2 - 1, params.n_cols - 1, params.ld2, params.is_row_major) + 1; + x2.resize(size, stream); + size = get_offset(params.n1 - 1, params.n2 - 1, params.ld_out, params.is_row_major) + 1; + + gram.resize(size, stream); + RAFT_CUDA_TRY(cudaMemsetAsync(gram.data(), 0, gram.size() * sizeof(math_t), stream)); + gram_host.resize(gram.size()); + std::fill(gram_host.begin(), gram_host.end(), 0); + + raft::random::Rng r(42137ULL); + r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream); + r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream); + } + + ~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); } + + int prepareCsr(math_t* dense, int n_rows, int ld, int* indptr, int* indices, math_t* data) + { + int nnz = 0; + double eps = 1e-6; + int n_cols = params.n_cols; + bool is_row_major = params.is_row_major; + size_t dense_size = get_offset(n_rows - 1, n_cols - 1, ld, is_row_major) + 1; + + std::vector dense_host(dense_size); + raft::update_host(dense_host.data(), dense, dense_size, stream); + handle.sync_stream(stream); + + std::vector indptr_host(n_rows + 1); + std::vector indices_host(n_rows * n_cols); + std::vector data_host(n_rows * n_cols); + + // create csr matrix from dense (with threshold) + for (int i = 0; i < n_rows; ++i) { + indptr_host[i] = nnz; + for (int j = 0; j < n_cols; ++j) { + math_t value = dense_host[get_offset(i, j, ld, is_row_major)]; + if (value > eps) { + indices_host[nnz] = j; + data_host[nnz] = value; + nnz++; + } + } + } + indptr_host[n_rows] = nnz; + + // fill back dense matrix from CSR + std::fill(dense_host.data(), dense_host.data() + dense_size, 0); + for (int i = 0; i < n_rows; ++i) { + for (int idx = indptr_host[i]; idx < indptr_host[i + 1]; ++idx) { + dense_host[get_offset(i, indices_host[idx], ld, is_row_major)] = data_host[idx]; + } + } + + raft::update_device(dense, dense_host.data(), dense_size, stream); + raft::update_device(indptr, indptr_host.data(), n_rows + 1, stream); + raft::update_device(indices, indices_host.data(), nnz, stream); + raft::update_device(data, data_host.data(), nnz, stream); + handle.sync_stream(stream); + + return nnz; + } + + void runTest() + { + std::unique_ptr> kernel = + std::unique_ptr>(KernelFactory::create(params.kernel)); + + auto x1_span = + params.is_row_major + ? raft::make_device_strided_matrix_view( + x1.data(), params.n1, params.n_cols, params.ld1) + : raft::make_device_strided_matrix_view( + x1.data(), params.n1, params.n_cols, params.ld1); + auto x2_span = + params.is_row_major + ? raft::make_device_strided_matrix_view( + x2.data(), params.n2, params.n_cols, params.ld2) + : raft::make_device_strided_matrix_view( + x2.data(), params.n2, params.n_cols, params.ld2); + auto out_span = + params.is_row_major + ? 
raft::make_device_strided_matrix_view( + gram.data(), params.n1, params.n2, params.ld_out) + : raft::make_device_strided_matrix_view( + gram.data(), params.n1, params.n2, params.ld_out); + + if (params.sparse_input == SparseType::DENSE) { + (*kernel)(handle, x1_span, x2_span, out_span); + } else { + x1_csr_indptr.reserve(params.n1 + 1, stream); + x1_csr_indices.reserve(params.n1 * params.n_cols, stream); + x1_csr_data.reserve(params.n1 * params.n_cols, stream); + int x1_nnz = prepareCsr(x1.data(), + params.n1, + params.ld1, + x1_csr_indptr.data(), + x1_csr_indices.data(), + x1_csr_data.data()); + + auto x1_csr_structure = raft::make_device_compressed_structure_view( + x1_csr_indptr.data(), x1_csr_indices.data(), params.n1, params.n_cols, x1_nnz); + auto x1_csr = raft::device_csr_matrix_view( + raft::device_span(x1_csr_data.data(), x1_csr_structure.get_nnz()), + x1_csr_structure); + + if (params.sparse_input == SparseType::MIX) { + (*kernel)(handle, x1_csr, x2_span, out_span); + } else { + x2_csr_indptr.reserve(params.n2 + 1, stream); + x2_csr_indices.reserve(params.n2 * params.n_cols, stream); + x2_csr_data.reserve(params.n2 * params.n_cols, stream); + int x2_nnz = prepareCsr(x2.data(), + params.n2, + params.ld2, + x2_csr_indptr.data(), + x2_csr_indices.data(), + x2_csr_data.data()); + + auto x2_csr_structure = raft::make_device_compressed_structure_view( + x2_csr_indptr.data(), x2_csr_indices.data(), params.n2, params.n_cols, x2_nnz); + auto x2_csr = raft::device_csr_matrix_view( + raft::device_span(x2_csr_data.data(), x2_csr_structure.get_nnz()), + x2_csr_structure); + + (*kernel)(handle, x1_csr, x2_csr, out_span); + } + } + + naiveGramMatrixKernel(params.n1, + params.n2, + params.n_cols, + x1, + x2, + gram_host.data(), + params.ld1, + params.ld2, + params.ld_out, + params.is_row_major, + params.kernel, + stream, + handle); + + handle.sync_stream(stream); + + ASSERT_TRUE(raft::devArrMatchHost( + gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f))); + } + + raft::device_resources handle; + cudaStream_t stream = 0; + GramMatrixInputs params; + + rmm::device_uvector x1; + rmm::device_uvector x2; + + rmm::device_uvector x1_csr_indptr; + rmm::device_uvector x1_csr_indices; + rmm::device_uvector x1_csr_data; + rmm::device_uvector x2_csr_indptr; + rmm::device_uvector x2_csr_indices; + rmm::device_uvector x2_csr_data; + + rmm::device_uvector gram; + std::vector gram_host; +}; + +typedef GramMatrixTest GramMatrixTestFloatStandard; +typedef GramMatrixTest GramMatrixTestFloatLd; +typedef GramMatrixTest GramMatrixTestFloatLdCsr; + +TEST_P(GramMatrixTestFloatStandard, Gram) { runTest(); } +TEST_P(GramMatrixTestFloatLd, Gram) { runTest(); } +TEST_P(GramMatrixTestFloatLdCsr, Gram) { runTest(); } + +INSTANTIATE_TEST_SUITE_P(GramMatrixTests, GramMatrixTestFloatStandard, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_SUITE_P(GramMatrixTests, GramMatrixTestFloatLd, ::testing::ValuesIn(inputs_ld)); +INSTANTIATE_TEST_SUITE_P(GramMatrixTests, + GramMatrixTestFloatLdCsr, + ::testing::ValuesIn(inputs_ld_csr)); +}; // end namespace raft::distance::kernels diff --git a/cpp/test/sparse/norm.cu b/cpp/test/sparse/norm.cu index 91b7b09fcc..65d857652c 100644 --- a/cpp/test/sparse/norm.cu +++ b/cpp/test/sparse/norm.cu @@ -19,7 +19,7 @@ #include "../test_utils.cuh" #include -#include +#include #include #include @@ -29,26 +29,24 @@ namespace raft { namespace sparse { -enum NormalizeMethod { MAX, L1 }; - template -struct CSRRowNormalizeInputs { - NormalizeMethod method; - std::vector ex_scan; - 
std::vector in_vals; +struct CSRRowNormInputs { + raft::linalg::NormType norm; + std::vector indptr; + std::vector data; std::vector verify; }; template -class CSRRowNormalizeTest : public ::testing::TestWithParam> { +class CSRRowNormTest : public ::testing::TestWithParam> { public: - CSRRowNormalizeTest() - : params(::testing::TestWithParam>::GetParam()), + CSRRowNormTest() + : params(::testing::TestWithParam>::GetParam()), stream(handle.get_stream()), - in_vals(params.in_vals.size(), stream), - verify(params.verify.size(), stream), - ex_scan(params.ex_scan.size(), stream), - result(params.verify.size(), stream) + data(params.data.size(), stream), + verify(params.indptr.size() - 1, stream), + indptr(params.indptr.size(), stream), + result(params.indptr.size() - 1, stream) { } @@ -57,71 +55,66 @@ class CSRRowNormalizeTest : public ::testing::TestWithParam( - ex_scan.data(), in_vals.data(), nnz, n_rows, result.data(), stream); - break; - case L1: - linalg::csr_row_normalize_l1( - ex_scan.data(), in_vals.data(), nnz, n_rows, result.data(), stream); - break; - } + Index_ n_rows = params.indptr.size() - 1; + Index_ nnz = params.data.size(); + + raft::update_device(indptr.data(), params.indptr.data(), n_rows + 1, stream); + raft::update_device(data.data(), params.data.data(), nnz, stream); + raft::update_device(verify.data(), params.verify.data(), n_rows, stream); + + linalg::rowNormCsr(handle, indptr.data(), data.data(), nnz, n_rows, result.data(), params.norm); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); ASSERT_TRUE( - raft::devArrMatch(verify.data(), result.data(), nnz, raft::Compare())); + raft::devArrMatch(verify.data(), result.data(), n_rows, raft::Compare())); } protected: raft::device_resources handle; cudaStream_t stream; - CSRRowNormalizeInputs params; - rmm::device_uvector ex_scan; - rmm::device_uvector in_vals, result, verify; + CSRRowNormInputs params; + rmm::device_uvector indptr; + rmm::device_uvector data, result, verify; }; -using CSRRowNormalizeTestF = CSRRowNormalizeTest; -TEST_P(CSRRowNormalizeTestF, Result) { Run(); } - -using CSRRowNormalizeTestD = CSRRowNormalizeTest; -TEST_P(CSRRowNormalizeTestD, Result) { Run(); } - -const std::vector> csrnormalize_inputs_f = { - {MAX, - {0, 4, 8, 9}, - {5.0, 1.0, 0.0, 0.0, 10.0, 1.0, 0.0, 0.0, 1.0, 0.0}, - {1.0, 0.2, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1, 0.0}}, - {L1, - {0, 4, 8, 9}, - {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, - {0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 1, 0.0}}, +using CSRRowNormTestF = CSRRowNormTest; +TEST_P(CSRRowNormTestF, Result) { Run(); } + +using CSRRowNormTestD = CSRRowNormTest; +TEST_P(CSRRowNormTestD, Result) { Run(); } + +const std::vector> csrnorm_inputs_f = { + {raft::linalg::NormType::LinfNorm, + {0, 3, 7, 10}, + {5.0, 1.0, 2.0, 0.0, 10.0, 1.0, 2.0, 1.0, 1.0, 2.0}, + {5.0, 10.0, 2.0}}, + {raft::linalg::NormType::L1Norm, + {0, 3, 7, 10}, + {5.0, 1.0, 2.0, 0.0, 10.0, 1.0, 2.0, 1.0, 1.0, 2.0}, + {8.0, 13.0, 4.0}}, + {raft::linalg::NormType::L2Norm, + {0, 3, 7, 10}, + {5.0, 1.0, 2.0, 0.0, 10.0, 1.0, 2.0, 1.0, 1.0, 2.0}, + {30.0, 105.0, 6.0}}, }; -const std::vector> csrnormalize_inputs_d = { - {MAX, - {0, 4, 8, 9}, - {5.0, 1.0, 0.0, 0.0, 10.0, 1.0, 0.0, 0.0, 1.0, 0.0}, - {1.0, 0.2, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1, 0.0}}, - {L1, - {0, 4, 8, 9}, - {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, - {0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 1, 0.0}}, +const std::vector> csrnorm_inputs_d = { + {raft::linalg::NormType::LinfNorm, + {0, 3, 7, 10}, + {5.0, 1.0, 2.0, 0.0, 10.0, 1.0, 2.0, 1.0, 1.0, 2.0}, + 
{5.0, 10.0, 2.0}}, + {raft::linalg::NormType::L1Norm, + {0, 3, 7, 10}, + {5.0, 1.0, 2.0, 0.0, 10.0, 1.0, 2.0, 1.0, 1.0, 2.0}, + {8.0, 13.0, 4.0}}, + {raft::linalg::NormType::L2Norm, + {0, 3, 7, 10}, + {5.0, 1.0, 2.0, 0.0, 10.0, 1.0, 2.0, 1.0, 1.0, 2.0}, + {30.0, 105.0, 6.0}}, }; -INSTANTIATE_TEST_CASE_P(SparseNormTest, - CSRRowNormalizeTestF, - ::testing::ValuesIn(csrnormalize_inputs_f)); -INSTANTIATE_TEST_CASE_P(SparseNormTest, - CSRRowNormalizeTestD, - ::testing::ValuesIn(csrnormalize_inputs_d)); +INSTANTIATE_TEST_CASE_P(SparseNormTest, CSRRowNormTestF, ::testing::ValuesIn(csrnorm_inputs_f)); +INSTANTIATE_TEST_CASE_P(SparseNormTest, CSRRowNormTestD, ::testing::ValuesIn(csrnorm_inputs_d)); } // namespace sparse } // namespace raft diff --git a/cpp/test/sparse/normalize.cu b/cpp/test/sparse/normalize.cu new file mode 100644 index 0000000000..91b7b09fcc --- /dev/null +++ b/cpp/test/sparse/normalize.cu @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../test_utils.cuh" + +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace sparse { + +enum NormalizeMethod { MAX, L1 }; + +template +struct CSRRowNormalizeInputs { + NormalizeMethod method; + std::vector ex_scan; + std::vector in_vals; + std::vector verify; +}; + +template +class CSRRowNormalizeTest : public ::testing::TestWithParam> { + public: + CSRRowNormalizeTest() + : params(::testing::TestWithParam>::GetParam()), + stream(handle.get_stream()), + in_vals(params.in_vals.size(), stream), + verify(params.verify.size(), stream), + ex_scan(params.ex_scan.size(), stream), + result(params.verify.size(), stream) + { + } + + protected: + void SetUp() override {} + + void Run() + { + Index_ n_rows = params.ex_scan.size(); + Index_ nnz = params.in_vals.size(); + + raft::update_device(ex_scan.data(), params.ex_scan.data(), n_rows, stream); + raft::update_device(in_vals.data(), params.in_vals.data(), nnz, stream); + raft::update_device(verify.data(), params.verify.data(), nnz, stream); + + switch (params.method) { + case MAX: + linalg::csr_row_normalize_max( + ex_scan.data(), in_vals.data(), nnz, n_rows, result.data(), stream); + break; + case L1: + linalg::csr_row_normalize_l1( + ex_scan.data(), in_vals.data(), nnz, n_rows, result.data(), stream); + break; + } + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + + ASSERT_TRUE( + raft::devArrMatch(verify.data(), result.data(), nnz, raft::Compare())); + } + + protected: + raft::device_resources handle; + cudaStream_t stream; + + CSRRowNormalizeInputs params; + rmm::device_uvector ex_scan; + rmm::device_uvector in_vals, result, verify; +}; + +using CSRRowNormalizeTestF = CSRRowNormalizeTest; +TEST_P(CSRRowNormalizeTestF, Result) { Run(); } + +using CSRRowNormalizeTestD = CSRRowNormalizeTest; +TEST_P(CSRRowNormalizeTestD, Result) { Run(); } + +const std::vector> csrnormalize_inputs_f = { + {MAX, + {0, 4, 8, 9}, + {5.0, 1.0, 0.0, 0.0, 10.0, 1.0, 0.0, 0.0, 1.0, 
0.0}, + {1.0, 0.2, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1, 0.0}}, + {L1, + {0, 4, 8, 9}, + {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, + {0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 1, 0.0}}, +}; +const std::vector> csrnormalize_inputs_d = { + {MAX, + {0, 4, 8, 9}, + {5.0, 1.0, 0.0, 0.0, 10.0, 1.0, 0.0, 0.0, 1.0, 0.0}, + {1.0, 0.2, 0.0, 0.0, 1.0, 0.1, 0.0, 0.0, 1, 0.0}}, + {L1, + {0, 4, 8, 9}, + {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, + {0.5, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 1, 0.0}}, +}; + +INSTANTIATE_TEST_CASE_P(SparseNormTest, + CSRRowNormalizeTestF, + ::testing::ValuesIn(csrnormalize_inputs_f)); +INSTANTIATE_TEST_CASE_P(SparseNormTest, + CSRRowNormalizeTestD, + ::testing::ValuesIn(csrnormalize_inputs_d)); + +} // namespace sparse +} // namespace raft
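Taken together, these changes move the Gram-matrix kernels from raw pointers plus an explicit cublasHandle_t to mdspan views driven by a raft::device_resources handle. The sketch below shows how a caller might exercise the new dense path; it is illustrative only: the umbrella header raft/distance/kernels.cuh, the polynomial parameters, and the matrix sizes (borrowed from the test inputs above) are assumptions, not part of this diff.

// Sketch: dense x dense Gram matrix through the new mdspan-based interface.
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/kernels.cuh>  // assumed umbrella header for GramMatrixBase / KernelFactory
#include <rmm/device_uvector.hpp>

#include <memory>

void gram_dense_example()
{
  using namespace raft::distance::kernels;

  raft::device_resources handle;
  auto stream = handle.get_stream();

  int n1 = 42, n2 = 137, n_cols = 2;  // sizes borrowed from the test inputs
  rmm::device_uvector<float> x1(n1 * n_cols, stream);
  rmm::device_uvector<float> x2(n2 * n_cols, stream);
  rmm::device_uvector<float> gram(n1 * n2, stream);
  // ... fill x1 and x2 with data here ...

  // No cuBLAS handle is passed any more; the factory takes only the kernel parameters.
  KernelParams params{KernelType::POLYNOMIAL, 2, 0.5, 2.4};
  std::unique_ptr<GramMatrixBase<float>> kernel(KernelFactory<float>::create(params));

  // Row-major strided views; the leading dimensions here are just the natural widths.
  auto x1_view = raft::make_device_strided_matrix_view<const float, int, raft::layout_c_contiguous>(
    x1.data(), n1, n_cols, n_cols);
  auto x2_view = raft::make_device_strided_matrix_view<const float, int, raft::layout_c_contiguous>(
    x2.data(), n2, n_cols, n_cols);
  auto out_view = raft::make_device_strided_matrix_view<float, int, raft::layout_c_contiguous>(
    gram.data(), n1, n2, n2);

  (*kernel)(handle, x1_view, x2_view, out_view);
  handle.sync_stream(stream);
}

The CSR overloads follow the same pattern: x1 (and, for the CSR/CSR case, x2) is passed as a raft::device_csr_matrix_view built from make_device_compressed_structure_view, exactly as exercised in cpp/test/sparse/gram.cu above.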