From 7deaa020a39659935b1958a5d5a0a2fe9f79bb41 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 8 May 2023 16:41:02 -0700 Subject: [PATCH 01/11] Remove raft/matrix/matrix.cuh includes The `raft/matrix/matrix.cuh` file has been marked as deprecated, and produces a compile warning when included. However it was still being referenced in a bunch of different spots within raft - making it hard to avoid these warnings. Remove the includes, in favour of either the newer API's or in certain cases the detail API --- .../raft/cluster/detail/kmeans_balanced.cuh | 1 - cpp/include/raft/linalg/detail/eig.cuh | 16 ++++++++-------- cpp/include/raft/linalg/detail/lstsq.cuh | 1 - cpp/include/raft/linalg/detail/qr.cuh | 4 ++-- cpp/include/raft/linalg/detail/rsvd.cuh | 7 ++++--- cpp/include/raft/linalg/detail/svd.cuh | 16 ++++++++-------- cpp/include/raft/matrix/copy.cuh | 2 +- cpp/include/raft/matrix/diagonal.cuh | 3 +-- cpp/include/raft/matrix/init.cuh | 1 - cpp/include/raft/matrix/linewise_op.cuh | 1 - cpp/include/raft/matrix/print.cuh | 1 - cpp/include/raft/matrix/sign_flip.cuh | 1 - cpp/include/raft/matrix/slice.cuh | 2 +- cpp/include/raft/matrix/sqrt.cuh | 1 - cpp/include/raft/matrix/triangular.cuh | 2 +- .../raft/neighbors/detail/ivf_pq_build.cuh | 8 ++++++-- cpp/include/raft/neighbors/refine-inl.cuh | 1 - .../raft/random/detail/make_regression.cuh | 5 +++-- .../raft/sparse/neighbors/detail/knn.cuh | 1 - .../raft/spatial/knn/detail/ball_cover.cuh | 18 +++++++++--------- cpp/test/linalg/svd.cu | 1 - cpp/test/matrix/matrix.cu | 2 +- cpp/test/neighbors/ann_utils.cuh | 9 ++++++--- 23 files changed, 51 insertions(+), 53 deletions(-) diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index eb89ebe402..9e5f7a7c9a 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -37,7 +37,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index 94493efb24..8ba9f62910 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -19,7 +19,7 @@ #include "cusolver_wrappers.hpp" #include #include -#include +#include #include #include #include @@ -52,7 +52,7 @@ void eigDC_legacy(raft::device_resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar d_dev_info(stream); - raft::matrix::copy(in, eig_vectors, n_rows, n_cols, stream); + raft::copy_async(eig_vectors, in, n_rows * n_cols, stream); RAFT_CUSOLVER_TRY(cusolverDnsyevd(cusolverH, CUSOLVER_EIG_MODE_VECTOR, @@ -108,7 +108,7 @@ void eigDC(raft::device_resources const& handle, rmm::device_scalar d_dev_info(stream); std::vector h_work(workspaceHost / sizeof(math_t)); - raft::matrix::copy(in, eig_vectors, n_rows, n_cols, stream); + raft::copy_async(eig_vectors, in, n_rows * n_cols, stream); RAFT_CUSOLVER_TRY(cusolverDnxsyevd(cusolverH, dn_params, @@ -191,7 +191,7 @@ void eigSelDC(raft::device_resources const& handle, stream)); } else if (memUsage == COPY_INPUT) { d_eig_vectors.resize(n_rows * n_cols, stream); - raft::matrix::copy(in, d_eig_vectors.data(), n_rows, n_cols, stream); + raft::copy_async(d_eig_vectors.data(), in, n_rows * n_cols, stream); RAFT_CUSOLVER_TRY(cusolverDnsyevdx(cusolverH, CUSOLVER_EIG_MODE_VECTOR, @@ -220,9 +220,9 @@ void eigSelDC(raft::device_resources const& handle, "This usually occurs when some of the features do not vary enough."); if (memUsage == OVERWRITE_INPUT) { - raft::matrix::truncZeroOrigin(in, n_rows, eig_vectors, n_rows, n_eig_vals, stream); + raft::matrix::detail::truncZeroOrigin(in, n_rows, eig_vectors, n_rows, n_eig_vals, stream); } else if (memUsage == COPY_INPUT) { - raft::matrix::truncZeroOrigin( + raft::matrix::detail::truncZeroOrigin( d_eig_vectors.data(), n_rows, eig_vectors, n_rows, n_eig_vals, stream); } } @@ -259,7 +259,7 @@ void eigJacobi(raft::device_resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar dev_info(stream); - raft::matrix::copy(in, eig_vectors, n_rows, n_cols, stream); + raft::copy_async(eig_vectors, in, n_rows * n_cols, stream); RAFT_CUSOLVER_TRY(cusolverDnsyevj(cusolverH, CUSOLVER_EIG_MODE_VECTOR, @@ -283,4 +283,4 @@ void eigJacobi(raft::device_resources const& handle, } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 207bcefc32..fd6b00f9fd 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -28,7 +28,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index bc7c551d89..721c02dc48 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -20,7 +20,7 @@ #include "cusolver_wrappers.hpp" #include #include -#include +#include #include #include @@ -132,7 +132,7 @@ void qrGetQR(raft::resources const& handle, devInfo.data(), stream)); - raft::matrix::copyUpperTriangular(R_full.data(), R, m, n, stream); + raft::matrix::detail::copyUpperTriangular(R_full.data(), R, m, n, stream); RAFT_CUDA_TRY( cudaMemcpyAsync(Q, R_full.data(), sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index a66a23179b..283b982e1f 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -21,8 +21,8 @@ #include #include #include +#include #include -#include #include #include @@ -272,7 +272,7 @@ void rsvdFixedRank(raft::device_resources const& handle, RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); rmm::device_uvector Uhat_dup(l * l, stream); RAFT_CUDA_TRY(cudaMemsetAsync(Uhat_dup.data(), 0, sizeof(math_t) * l * l, stream)); - raft::matrix::copyUpperTriangular(BBt.data(), Uhat_dup.data(), l, l, stream); + raft::matrix::detail::copyUpperTriangular(BBt.data(), Uhat_dup.data(), l, l, stream); if (use_jacobi) raft::linalg::eigJacobi( handle, Uhat_dup.data(), l, l, Uhat.data(), S_vec_tmp.data(), stream, tol, max_sweeps); @@ -316,7 +316,8 @@ void rsvdFixedRank(raft::device_resources const& handle, rmm::device_uvector UhatSinv(l * k, stream); RAFT_CUDA_TRY(cudaMemsetAsync(UhatSinv.data(), 0, sizeof(math_t) * l * k, stream)); raft::matrix::reciprocal(S_vec_tmp.data(), l, stream); - raft::matrix::initializeDiagonalMatrix(S_vec_tmp.data() + p, Sinv.data(), k, k, stream); + raft::matrix::detail::initializeDiagonalMatrix( + S_vec_tmp.data() + p, Sinv.data(), k, k, stream); raft::linalg::gemm(handle, Uhat.data() + p * l, diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 998bea5b1b..d07e444003 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -24,8 +24,8 @@ #include #include +#include #include -#include #include #include #include @@ -285,15 +285,15 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream)); RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, sizeof(math_t) * k * k, stream)); - raft::matrix::initializeDiagonalMatrix(S_vec, S_mat.data(), k, k, stream); + raft::matrix::detail::initializeDiagonalMatrix(S_vec, S_mat.data(), k, k, stream); svdReconstruction(handle, U, S_mat.data(), V, P_d.data(), m, n, k, stream); // get norms of each - math_t normA = raft::matrix::getL2Norm(handle, A_d, m * n, stream); - math_t normU = raft::matrix::getL2Norm(handle, U, m * k, stream); - math_t normS = raft::matrix::getL2Norm(handle, S_mat.data(), k * k, stream); - math_t normV = raft::matrix::getL2Norm(handle, V, n * k, stream); - math_t normP = raft::matrix::getL2Norm(handle, P_d.data(), m * n, stream); + math_t normA = raft::matrix::detail::getL2Norm(handle, A_d, m * n, stream); + math_t normU = raft::matrix::detail::getL2Norm(handle, U, m * k, stream); + math_t normS = raft::matrix::detail::getL2Norm(handle, S_mat.data(), k * k, stream); + math_t normV = raft::matrix::detail::getL2Norm(handle, V, n * k, stream); + math_t normP = raft::matrix::detail::getL2Norm(handle, P_d.data(), m * n, stream); // calculate percent error const math_t alpha = 1.0, beta = -1.0; @@ -315,7 +315,7 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, m, stream)); - math_t norm_A_minus_P = raft::matrix::getL2Norm(handle, A_minus_P.data(), m * n, stream); + math_t norm_A_minus_P = raft::matrix::detail::getL2Norm(handle, A_minus_P.data(), m * n, stream); math_t percent_error = 100.0 * norm_A_minus_P / normA; return (percent_error / 100.0 < tol); } diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 42d2562e5e..f5bacc32dd 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -42,7 +42,7 @@ template void copy_rows(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, - raft::device_vector_view indices) + raft::device_vector_view indices) { RAFT_EXPECTS(in.extent(1) == out.extent(1), "Input and output matrices must have same number of columns"); diff --git a/cpp/include/raft/matrix/diagonal.cuh b/cpp/include/raft/matrix/diagonal.cuh index 22147e9f34..e0141cbf01 100644 --- a/cpp/include/raft/matrix/diagonal.cuh +++ b/cpp/include/raft/matrix/diagonal.cuh @@ -18,7 +18,6 @@ #include #include -#include namespace raft::matrix { @@ -84,4 +83,4 @@ void invert_diagonal(raft::device_resources const& handle, /** @} */ // end of group matrix_diagonal -} // namespace raft::matrix \ No newline at end of file +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh index ed2fb4d209..9611e044f4 100644 --- a/cpp/include/raft/matrix/init.cuh +++ b/cpp/include/raft/matrix/init.cuh @@ -20,7 +20,6 @@ #include #include #include -#include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 33de112a35..056ef4f411 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -18,7 +18,6 @@ #include #include -#include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh index 6a4bfbdd01..f2c2653211 100644 --- a/cpp/include/raft/matrix/print.cuh +++ b/cpp/include/raft/matrix/print.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh index d069c55880..93962fb67d 100644 --- a/cpp/include/raft/matrix/sign_flip.cuh +++ b/cpp/include/raft/matrix/sign_flip.cuh @@ -18,7 +18,6 @@ #include #include -#include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/slice.cuh b/cpp/include/raft/matrix/slice.cuh index bb92b2b86f..071a10a847 100644 --- a/cpp/include/raft/matrix/slice.cuh +++ b/cpp/include/raft/matrix/slice.cuh @@ -76,4 +76,4 @@ void slice(raft::device_resources const& handle, /** @} */ // end group matrix_slice -} // namespace raft::matrix \ No newline at end of file +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh index 9729f9b3d5..309ae3452f 100644 --- a/cpp/include/raft/matrix/sqrt.cuh +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -19,7 +19,6 @@ #include #include #include -#include namespace raft::matrix { diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh index 3c60cc362f..1388c829b0 100644 --- a/cpp/include/raft/matrix/triangular.cuh +++ b/cpp/include/raft/matrix/triangular.cuh @@ -46,4 +46,4 @@ void upper_triangular(raft::device_resources const& handle, /** @} */ // end group matrix_triangular -} // namespace raft::matrix \ No newline at end of file +} // namespace raft::matrix diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh index b17b3a3559..53d8823eea 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh @@ -181,8 +181,12 @@ void select_residuals(raft::device_resources const& handle, dataset, utils::mapping{}); raft::matrix::gather(mapping_itr, (IdxT)dim, n_rows, row_ids, n_rows, tmp.data(), stream); - raft::matrix::linewiseOp( - tmp.data(), tmp.data(), IdxT(dim), n_rows, true, raft::sub_op{}, stream, center); + raft::matrix::linewise_op(handle, + make_device_matrix_view(tmp.data(), n_rows, dim), + make_device_matrix_view(tmp.data(), n_rows, dim), + true, + raft::sub_op{}, + make_device_vector_view(center, dim)); float alpha = 1.0; float beta = 0.0; diff --git a/cpp/include/raft/neighbors/refine-inl.cuh b/cpp/include/raft/neighbors/refine-inl.cuh index 4243d7e723..2c4dfb422e 100644 --- a/cpp/include/raft/neighbors/refine-inl.cuh +++ b/cpp/include/raft/neighbors/refine-inl.cuh @@ -19,7 +19,6 @@ #include #include #include -#include #include #include diff --git a/cpp/include/raft/random/detail/make_regression.cuh b/cpp/include/raft/random/detail/make_regression.cuh index 1715dcbe81..3a236d8834 100644 --- a/cpp/include/raft/random/detail/make_regression.cuh +++ b/cpp/include/raft/random/detail/make_regression.cuh @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include #include @@ -83,7 +83,8 @@ static void _make_low_rank_matrix(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); rmm::device_uvector singular_mat(n * n, stream); RAFT_CUDA_TRY(cudaMemsetAsync(singular_mat.data(), 0, n * n * sizeof(DataT), stream)); - raft::matrix::initializeDiagonalMatrix(singular_vec.data(), singular_mat.data(), n, n, stream); + raft::matrix::detail::initializeDiagonalMatrix( + singular_vec.data(), singular_mat.data(), n, n, stream); // Generate the column-major matrix rmm::device_uvector temp_q0s(n_rows * n, stream); diff --git a/cpp/include/raft/sparse/neighbors/detail/knn.cuh b/cpp/include/raft/sparse/neighbors/detail/knn.cuh index 6649c10c47..527fc14208 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn.cuh @@ -20,7 +20,6 @@ #include #include -#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index c8fc6eefda..fe18da7f62 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -30,7 +30,7 @@ #include -#include +#include #include #include #include @@ -94,14 +94,14 @@ void sample_landmarks(raft::device_resources const& handle, (value_idx)index.n_landmarks, (value_idx)index.m); - raft::matrix::copyRows(index.get_X().data_handle(), - index.m, - index.n, - index.get_R().data_handle(), - R_1nn_cols2.data(), - index.n_landmarks, - handle.get_stream(), - true); + // index.get_X() returns the wrong indextype (uint32_t where we need value_idx), so need to + // create new device_matrix_view here + auto x = make_device_matrix_view( + index.get_X().data_handle(), index.m, index.n); + auto r = + make_device_matrix_view(index.get_R().data_handle(), index.m, index.n); + raft::matrix::copy_rows( + handle, x, r, make_device_vector_view(R_1nn_cols2.data(), index.n_landmarks)); } /** diff --git a/cpp/test/linalg/svd.cu b/cpp/test/linalg/svd.cu index bd66459962..a2bca891e5 100644 --- a/cpp/test/linalg/svd.cu +++ b/cpp/test/linalg/svd.cu @@ -18,7 +18,6 @@ #include #include #include -#include #include #include diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 10105203f7..07ab3c5ce4 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -143,7 +143,7 @@ class MatrixCopyRowsTest : public ::testing::Test { output.data(), n_selected, n_cols); auto indices_view = - raft::make_device_vector_view(indices.data(), n_selected); + raft::make_device_vector_view(indices.data(), n_selected); raft::matrix::copy_rows(handle, input_view, output_view, indices_view); diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 438c56da21..67df5f2abe 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -18,8 +18,8 @@ #include // raft::make_device_matrix #include +#include #include -#include #include #include @@ -188,8 +188,11 @@ auto eval_distances(raft::device_resources const& handle, auto y = raft::make_device_matrix(handle, k, n_cols); auto naive_dist = raft::make_device_matrix(handle, 1, k); - raft::matrix::copyRows( - x, k, n_cols, y.data_handle(), neighbors + i * k, k, handle.get_stream(), true); + raft::matrix::copy_rows( + handle, + make_device_matrix_view(x, k, n_cols), + y.view(), + make_device_vector_view(neighbors + i * k, k)); dim3 block_dim(16, 32, 1); auto grid_y = From 1de19cde17c93bb162ca04e83babf97ed1b90a04 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 8 May 2023 21:57:50 -0700 Subject: [PATCH 02/11] incremental progress --- cpp/include/raft/linalg/detail/eig.cuh | 30 ++++++++++++++----- cpp/include/raft/linalg/detail/qr.cuh | 7 +++-- cpp/include/raft/linalg/detail/rsvd.cuh | 9 ++++-- cpp/include/raft/matrix/copy.cuh | 18 +++++++++++ cpp/include/raft/matrix/triangular.cuh | 11 ++++--- .../raft/spatial/knn/detail/ball_cover.cuh | 12 ++++---- 6 files changed, 66 insertions(+), 21 deletions(-) diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index 8ba9f62910..7896136631 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -19,7 +19,7 @@ #include "cusolver_wrappers.hpp" #include #include -#include +#include #include #include #include @@ -52,7 +52,9 @@ void eigDC_legacy(raft::device_resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar d_dev_info(stream); - raft::copy_async(eig_vectors, in, n_rows * n_cols, stream); + raft::matrix::copy(handle, + make_device_matrix_view(in, n_rows, n_cols), + make_device_matrix_view(eig_vectors, n_rows, n_cols)); RAFT_CUSOLVER_TRY(cusolverDnsyevd(cusolverH, CUSOLVER_EIG_MODE_VECTOR, @@ -108,7 +110,9 @@ void eigDC(raft::device_resources const& handle, rmm::device_scalar d_dev_info(stream); std::vector h_work(workspaceHost / sizeof(math_t)); - raft::copy_async(eig_vectors, in, n_rows * n_cols, stream); + raft::matrix::copy(handle, + make_device_matrix_view(in, n_rows, n_cols), + make_device_matrix_view(eig_vectors, n_rows, n_cols)); RAFT_CUSOLVER_TRY(cusolverDnxsyevd(cusolverH, dn_params, @@ -191,7 +195,9 @@ void eigSelDC(raft::device_resources const& handle, stream)); } else if (memUsage == COPY_INPUT) { d_eig_vectors.resize(n_rows * n_cols, stream); - raft::copy_async(d_eig_vectors.data(), in, n_rows * n_cols, stream); + raft::matrix::copy(handle, + make_device_matrix_view(in, n_rows, n_cols), + make_device_matrix_view(eig_vectors, n_rows, n_cols)); RAFT_CUSOLVER_TRY(cusolverDnsyevdx(cusolverH, CUSOLVER_EIG_MODE_VECTOR, @@ -220,10 +226,16 @@ void eigSelDC(raft::device_resources const& handle, "This usually occurs when some of the features do not vary enough."); if (memUsage == OVERWRITE_INPUT) { - raft::matrix::detail::truncZeroOrigin(in, n_rows, eig_vectors, n_rows, n_eig_vals, stream); + raft::matrix::trunc_zero_origin( + handle, + make_device_matrix_view(in, n_rows, n_eig_vals), + make_device_matrix_view(eig_vectors, n_rows, n_eig_vals)); } else if (memUsage == COPY_INPUT) { - raft::matrix::detail::truncZeroOrigin( - d_eig_vectors.data(), n_rows, eig_vectors, n_rows, n_eig_vals, stream); + raft::matrix::trunc_zero_origin( + handle, + make_device_matrix_view( + d_eig_vectors.data(), n_rows, n_eig_vals), + make_device_matrix_view(eig_vectors, n_rows, n_eig_vals)); } } @@ -259,7 +271,9 @@ void eigJacobi(raft::device_resources const& handle, rmm::device_uvector d_work(lwork, stream); rmm::device_scalar dev_info(stream); - raft::copy_async(eig_vectors, in, n_rows * n_cols, stream); + raft::matrix::copy(handle, + make_device_matrix_view(in, n_rows, n_cols), + make_device_matrix_view(eig_vectors, n_rows, n_cols)); RAFT_CUSOLVER_TRY(cusolverDnsyevj(cusolverH, CUSOLVER_EIG_MODE_VECTOR, diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 721c02dc48..125abe9cf0 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -20,7 +20,7 @@ #include "cusolver_wrappers.hpp" #include #include -#include +#include #include #include @@ -132,7 +132,10 @@ void qrGetQR(raft::resources const& handle, devInfo.data(), stream)); - raft::matrix::detail::copyUpperTriangular(R_full.data(), R, m, n, stream); + raft::matrix::upper_triangular( + handle, + make_device_matrix_view(R_full.data(), m, n), + make_device_matrix_view(R, m, n)); RAFT_CUDA_TRY( cudaMemcpyAsync(Q, R_full.data(), sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 283b982e1f..6c1345830a 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -21,8 +21,8 @@ #include #include #include -#include #include +#include #include #include @@ -272,7 +272,12 @@ void rsvdFixedRank(raft::device_resources const& handle, RAFT_CUDA_TRY(cudaMemsetAsync(Uhat.data(), 0, sizeof(math_t) * l * l, stream)); rmm::device_uvector Uhat_dup(l * l, stream); RAFT_CUDA_TRY(cudaMemsetAsync(Uhat_dup.data(), 0, sizeof(math_t) * l * l, stream)); - raft::matrix::detail::copyUpperTriangular(BBt.data(), Uhat_dup.data(), l, l, stream); + + raft::matrix::upper_triangular( + handle, + make_device_matrix_view(BBt.data(), l, l), + make_device_matrix_view(Uhat_dup.data(), l, l)); + if (use_jacobi) raft::linalg::eigJacobi( handle, Uhat_dup.data(), l, l, Uhat.data(), S_vec_tmp.data(), stream, tol, max_sweeps); diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index f5bacc32dd..e4e5526e71 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -58,6 +58,24 @@ void copy_rows(raft::device_resources const& handle, raft::is_row_major(in)); } +/** + * @brief copy matrix operation for row major matrices. + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix + */ +template +void copy(raft::device_resources const& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), + "Input and output matrix shapes must match."); + + raft::copy_async( + out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); +} + /** * @brief copy matrix operation for column major matrices. * @param[in] handle: raft handle diff --git a/cpp/include/raft/matrix/triangular.cuh b/cpp/include/raft/matrix/triangular.cuh index 1388c829b0..0c89140046 100644 --- a/cpp/include/raft/matrix/triangular.cuh +++ b/cpp/include/raft/matrix/triangular.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -33,17 +34,19 @@ namespace raft::matrix { * @param[out] dst: output matrix with a size of kxk, k = min(n_rows, n_cols) */ template -void upper_triangular(raft::device_resources const& handle, +void upper_triangular(raft::resources const& handle, raft::device_matrix_view src, raft::device_matrix_view dst) { auto k = std::min(src.extent(0), src.extent(1)); RAFT_EXPECTS(k == dst.extent(0) && k == dst.extent(1), "dst should be of size kxk, k = min(n_rows, n_cols)"); - detail::copyUpperTriangular( - src.data_handle(), dst.data_handle(), src.extent(0), src.extent(1), handle.get_stream()); + detail::copyUpperTriangular(src.data_handle(), + dst.data_handle(), + src.extent(0), + src.extent(1), + resource::get_cuda_stream(handle)); } - /** @} */ // end group matrix_triangular } // namespace raft::matrix diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index fe18da7f62..a58847ee41 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -96,12 +96,14 @@ void sample_landmarks(raft::device_resources const& handle, // index.get_X() returns the wrong indextype (uint32_t where we need value_idx), so need to // create new device_matrix_view here - auto x = make_device_matrix_view( - index.get_X().data_handle(), index.m, index.n); - auto r = - make_device_matrix_view(index.get_R().data_handle(), index.m, index.n); + auto x = index.get_X(); + auto r = index.get_R(); + raft::matrix::copy_rows( - handle, x, r, make_device_vector_view(R_1nn_cols2.data(), index.n_landmarks)); + handle, + make_device_matrix_view(x.data_handle(), x.extent(0), x.extent(1)), + make_device_matrix_view(r.data_handle(), r.extent(0), r.extent(1)), + make_device_vector_view(R_1nn_cols2.data(), index.n_landmarks)); } /** From 427b55fd2c32e93eb613fc5a011ee32cc371e583 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 8 May 2023 22:19:54 -0700 Subject: [PATCH 03/11] set_diagonal --- cpp/include/raft/linalg/detail/rsvd.cuh | 6 ++++-- cpp/include/raft/linalg/detail/svd.cuh | 6 ++++-- cpp/include/raft/matrix/diagonal.cuh | 14 ++++++++------ cpp/include/raft/random/detail/make_regression.cuh | 8 +++++--- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 6c1345830a..310117414a 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -321,8 +322,9 @@ void rsvdFixedRank(raft::device_resources const& handle, rmm::device_uvector UhatSinv(l * k, stream); RAFT_CUDA_TRY(cudaMemsetAsync(UhatSinv.data(), 0, sizeof(math_t) * l * k, stream)); raft::matrix::reciprocal(S_vec_tmp.data(), l, stream); - raft::matrix::detail::initializeDiagonalMatrix( - S_vec_tmp.data() + p, Sinv.data(), k, k, stream); + raft::matrix::set_diagonal(handle, + make_device_vector_view(S_vec_tmp.data() + p, k), + make_device_matrix_view(Sinv.data(), k, k)); raft::linalg::gemm(handle, Uhat.data() + p * l, diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index d07e444003..30616ecd54 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include #include @@ -285,7 +285,9 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, RAFT_CUDA_TRY(cudaMemsetAsync(P_d.data(), 0, sizeof(math_t) * m * n, stream)); RAFT_CUDA_TRY(cudaMemsetAsync(S_mat.data(), 0, sizeof(math_t) * k * k, stream)); - raft::matrix::detail::initializeDiagonalMatrix(S_vec, S_mat.data(), k, k, stream); + raft::matrix::set_diagonal(handle, + make_device_vector_view(S_vec, k), + make_device_matrix_view(S_mat.data(), k, k)); svdReconstruction(handle, U, S_mat.data(), V, P_d.data(), m, n, k, stream); // get norms of each diff --git a/cpp/include/raft/matrix/diagonal.cuh b/cpp/include/raft/matrix/diagonal.cuh index e0141cbf01..c7a3681983 100644 --- a/cpp/include/raft/matrix/diagonal.cuh +++ b/cpp/include/raft/matrix/diagonal.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include namespace raft::matrix { @@ -33,7 +34,7 @@ namespace raft::matrix { * @param[out] matrix: matrix of size n_rows x n_cols */ template -void set_diagonal(raft::device_resources const& handle, +void set_diagonal(raft::resources const& handle, raft::device_vector_view vec, raft::device_matrix_view matrix) { @@ -44,7 +45,7 @@ void set_diagonal(raft::device_resources const& handle, matrix.data_handle(), matrix.extent(0), matrix.extent(1), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -54,7 +55,7 @@ void set_diagonal(raft::device_resources const& handle, * @param[out] vec: vector of length k = min(n_rows, n_cols) */ template -void get_diagonal(raft::device_resources const& handle, +void get_diagonal(raft::resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec) { @@ -64,7 +65,7 @@ void get_diagonal(raft::device_resources const& handle, matrix.data_handle(), matrix.extent(0), matrix.extent(1), - handle.get_stream()); + resource::get_cuda_stream(handle)); } /** @@ -73,12 +74,13 @@ void get_diagonal(raft::device_resources const& handle, * @param[inout] inout: square input matrix with size len x len */ template -void invert_diagonal(raft::device_resources const& handle, +void invert_diagonal(raft::resources const& handle, raft::device_matrix_view inout) { // TODO: Use get_diagonal for this to support rectangular RAFT_EXPECTS(inout.extent(0) == inout.extent(1), "Matrix must be square."); - detail::getDiagonalInverseMatrix(inout.data_handle(), inout.extent(0), handle.get_stream()); + detail::getDiagonalInverseMatrix( + inout.data_handle(), inout.extent(0), resource::get_cuda_stream(handle)); } /** @} */ // end of group matrix_diagonal diff --git a/cpp/include/raft/random/detail/make_regression.cuh b/cpp/include/raft/random/detail/make_regression.cuh index 3a236d8834..aec1a15f84 100644 --- a/cpp/include/raft/random/detail/make_regression.cuh +++ b/cpp/include/raft/random/detail/make_regression.cuh @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include #include @@ -83,8 +83,10 @@ static void _make_low_rank_matrix(raft::resources const& handle, RAFT_CUDA_TRY(cudaPeekAtLastError()); rmm::device_uvector singular_mat(n * n, stream); RAFT_CUDA_TRY(cudaMemsetAsync(singular_mat.data(), 0, n * n * sizeof(DataT), stream)); - raft::matrix::detail::initializeDiagonalMatrix( - singular_vec.data(), singular_mat.data(), n, n, stream); + + raft::matrix::set_diagonal(handle, + make_device_vector_view(singular_vec.data(), n), + make_device_matrix_view(singular_mat.data(), n, n)); // Generate the column-major matrix rmm::device_uvector temp_q0s(n_rows * n, stream); From b1cb13759f67be6cd6ea6be9725a8f3afaf0e249 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 8 May 2023 22:35:08 -0700 Subject: [PATCH 04/11] l2_norm --- cpp/include/raft/linalg/detail/svd.cuh | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 30616ecd54..d22cae6dcb 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -291,11 +292,13 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, svdReconstruction(handle, U, S_mat.data(), V, P_d.data(), m, n, k, stream); // get norms of each - math_t normA = raft::matrix::detail::getL2Norm(handle, A_d, m * n, stream); - math_t normU = raft::matrix::detail::getL2Norm(handle, U, m * k, stream); - math_t normS = raft::matrix::detail::getL2Norm(handle, S_mat.data(), k * k, stream); - math_t normV = raft::matrix::detail::getL2Norm(handle, V, n * k, stream); - math_t normP = raft::matrix::detail::getL2Norm(handle, P_d.data(), m * n, stream); + math_t normA = raft::matrix::l2_norm(handle, make_device_matrix_view(A_d, m, n)); + math_t normU = raft::matrix::l2_norm(handle, make_device_matrix_view(U, m, k)); + math_t normS = + raft::matrix::l2_norm(handle, make_device_matrix_view(S_mat.data(), k, k)); + math_t normV = raft::matrix::l2_norm(handle, make_device_matrix_view(V, n, k)); + math_t normP = + raft::matrix::l2_norm(handle, make_device_matrix_view(P_d.data(), m, n)); // calculate percent error const math_t alpha = 1.0, beta = -1.0; @@ -317,8 +320,9 @@ bool evaluateSVDByL2Norm(raft::device_resources const& handle, m, stream)); - math_t norm_A_minus_P = raft::matrix::detail::getL2Norm(handle, A_minus_P.data(), m * n, stream); - math_t percent_error = 100.0 * norm_A_minus_P / normA; + math_t norm_A_minus_P = + raft::matrix::l2_norm(handle, make_device_matrix_view(A_minus_P.data(), m, n)); + math_t percent_error = 100.0 * norm_A_minus_P / normA; return (percent_error / 100.0 < tol); } From bb9b0be5454a0fd384b6457e0510461cd2d2eb44 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 10:15:49 -0700 Subject: [PATCH 05/11] fix rsvd unittests --- cpp/include/raft/linalg/detail/qr.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 125abe9cf0..16a721dfd3 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -135,7 +135,7 @@ void qrGetQR(raft::resources const& handle, raft::matrix::upper_triangular( handle, make_device_matrix_view(R_full.data(), m, n), - make_device_matrix_view(R, m, n)); + make_device_matrix_view(R, std::min(m, n), std::min(m, n))); RAFT_CUDA_TRY( cudaMemcpyAsync(Q, R_full.data(), sizeof(math_t) * m * n, cudaMemcpyDeviceToDevice, stream)); From 867ff438d4ab211761ce883f817be9cda0d6eb99 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 10:24:40 -0700 Subject: [PATCH 06/11] reverse --- cpp/include/raft/linalg/detail/rsvd.cuh | 7 ++++--- cpp/include/raft/linalg/detail/svd.cuh | 7 +++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 310117414a..91ea4c9c4c 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -294,7 +295,7 @@ void rsvdFixedRank(raft::device_resources const& handle, 1, l, stream); // Last k elements of S_vec - raft::matrix::colReverse(S_vec, 1, k, stream); + raft::matrix::col_reverse(handle, make_device_matrix_view(S_vec, 1, k)); // Merge step 14 & 15 by calculating U = Q*Uhat[:,(p+1):l] mxl * lxk = mxk if (gen_left_vec) { @@ -311,7 +312,7 @@ void rsvdFixedRank(raft::device_resources const& handle, alpha, beta, stream); - raft::matrix::colReverse(U, m, k, stream); + raft::matrix::col_reverse(handle, make_device_matrix_view(U, m, k)); } // Merge step 14 & 15 by calculating V = B^T Uhat[:,(p+1):l] * @@ -352,7 +353,7 @@ void rsvdFixedRank(raft::device_resources const& handle, alpha, beta, stream); - raft::matrix::colReverse(V, n, k, stream); + raft::matrix::col_reverse(handle, make_device_matrix_view(V, n, k)); } } } diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index d22cae6dcb..94cd9e2789 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -140,8 +141,10 @@ void svdEig(raft::device_resources const& handle, raft::linalg::eigDC(handle, in_cross_mult.data(), n_cols, n_cols, V, S, stream); - raft::matrix::colReverse(V, n_cols, n_cols, stream); - raft::matrix::rowReverse(S, n_cols, idx_t(1), stream); + raft::matrix::col_reverse(handle, + make_device_matrix_view(V, n_cols, n_cols)); + raft::matrix::row_reverse(handle, + make_device_matrix_view(S, n_cols, idx_t(1))); raft::matrix::seqRoot(S, S, alpha, n_cols, stream, true); From c599dd86eb9121b7a2de4d78ea0fdee493664060 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 11:26:27 -0700 Subject: [PATCH 07/11] slice --- cpp/include/raft/linalg/detail/rsvd.cuh | 35 +++++++++++-------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index 91ea4c9c4c..48b9e1d2db 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -204,15 +205,13 @@ void rsvdFixedRank(raft::device_resources const& handle, true, true, stream); - raft::matrix::sliceMatrix(S_vec_tmp.data(), - 1, - l, - S_vec, - 0, - 0, - 1, - k, - stream); // First k elements of S_vec + + // First k elements of S_vec + raft::matrix::slice( + handle, + make_device_matrix_view(S_vec_tmp.data(), 1, l), + make_device_matrix_view(S_vec, 1, k), + raft::matrix::slice_coordinates(0, 0, 1, k)); // Merge step 14 & 15 by calculating U = Q*Vhat[:,1:k] mxl * lxk = mxk if (gen_left_vec) { @@ -286,16 +285,14 @@ void rsvdFixedRank(raft::device_resources const& handle, else raft::linalg::eigDC(handle, Uhat_dup.data(), l, l, Uhat.data(), S_vec_tmp.data(), stream); raft::matrix::seqRoot(S_vec_tmp.data(), l, stream); - raft::matrix::sliceMatrix(S_vec_tmp.data(), - 1, - l, - S_vec, - 0, - p, - 1, - l, - stream); // Last k elements of S_vec - raft::matrix::col_reverse(handle, make_device_matrix_view(S_vec, 1, k)); + + auto S_vec_view = make_device_matrix_view(S_vec, 1, k); + raft::matrix::slice( + handle, + raft::make_device_matrix_view(S_vec_tmp.data(), 1, l), + S_vec_view, + raft::matrix::slice_coordinates(0, p, 1, l)); // Last k elements of S_vec + raft::matrix::col_reverse(handle, S_vec_view); // Merge step 14 & 15 by calculating U = Q*Uhat[:,(p+1):l] mxl * lxk = mxk if (gen_left_vec) { From 87602aa7485550c42eab1cf68a37b7c5e294a797 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 11:26:45 -0700 Subject: [PATCH 08/11] linewise_op --- .../raft/linalg/detail/matrix_vector_op.cuh | 53 +++++++++++++++---- cpp/include/raft/linalg/matrix_vector_op.cuh | 3 +- cpp/include/raft/matrix/linewise_op.cuh | 3 +- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index 62ec9bb7a4..c51ae2065a 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include namespace raft { namespace linalg { @@ -33,10 +33,26 @@ void matrixVectorOp(MatT* out, Lambda op, cudaStream_t stream) { - IdxType stride = rowMajor ? D : N; - IdxType nLines = rowMajor ? N : D; - return matrix::linewiseOp( - out, matrix, stride, nLines, rowMajor == bcastAlongRows, op, stream, vec); + raft::device_resources handle(stream); + + bool along_lines = rowMajor == bcastAlongRows; + if (rowMajor) { + matrix::linewise_op( + handle, + make_device_matrix_view(matrix, D, N), + make_device_matrix_view(out, D, N), + along_lines, + op, + make_device_vector_view(vec, along_lines ? D : N)); + } else { + matrix::linewise_op( + handle, + make_device_matrix_view(matrix, D, N), + make_device_matrix_view(out, D, N), + along_lines, + op, + make_device_vector_view(vec, along_lines ? D : N)); + } } template ( + handle, + make_device_matrix_view(matrix, D, N), + make_device_matrix_view(out, D, N), + along_lines, + op, + make_device_vector_view(vec1, along_lines ? D : N), + make_device_vector_view(vec2, along_lines ? D : N)); + } else { + matrix::linewise_op( + handle, + make_device_matrix_view(matrix, D, N), + make_device_matrix_view(out, D, N), + along_lines, + op, + make_device_vector_view(vec1, along_lines ? D : N), + make_device_vector_view(vec2, along_lines ? D : N)); + } } }; // end namespace detail diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 6c65626ac5..e8833a2779 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -22,6 +22,7 @@ #include "linalg_types.hpp" #include +#include #include namespace raft { @@ -241,4 +242,4 @@ void matrix_vector_op(raft::device_resources const& handle, }; // end namespace linalg }; // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh index 056ef4f411..f8e3555d9d 100644 --- a/cpp/include/raft/matrix/linewise_op.cuh +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -17,7 +17,8 @@ #pragma once #include -#include +#include +#include namespace raft::matrix { From 1d91909e4c81ea057b25cffa7f49e241089914f1 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 15:08:22 -0700 Subject: [PATCH 09/11] fix --- cpp/include/raft/linalg/detail/matrix_vector_op.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index c51ae2065a..1f83c3d099 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -39,11 +39,11 @@ void matrixVectorOp(MatT* out, if (rowMajor) { matrix::linewise_op( handle, - make_device_matrix_view(matrix, D, N), - make_device_matrix_view(out, D, N), + make_device_matrix_view(matrix, N, D), + make_device_matrix_view(out, N, D), along_lines, op, - make_device_vector_view(vec, along_lines ? D : N)); + make_device_vector_view(vec, along_lines ? N : D)); } else { matrix::linewise_op( handle, From 7e97f767e796957058f9e65bba9685dda95613cb Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 15:12:06 -0700 Subject: [PATCH 10/11] . --- .../raft/linalg/detail/matrix_vector_op.cuh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index 1f83c3d099..6312e58b37 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -43,7 +43,7 @@ void matrixVectorOp(MatT* out, make_device_matrix_view(out, N, D), along_lines, op, - make_device_vector_view(vec, along_lines ? N : D)); + make_device_vector_view(vec, bcastAlongRows ? N : D)); } else { matrix::linewise_op( handle, @@ -51,7 +51,7 @@ void matrixVectorOp(MatT* out, make_device_matrix_view(out, D, N), along_lines, op, - make_device_vector_view(vec, along_lines ? D : N)); + make_device_vector_view(vec, bcastAlongRows ? D : N)); } } @@ -77,12 +77,12 @@ void matrixVectorOp(MatT* out, if (rowMajor) { matrix::linewise_op( handle, - make_device_matrix_view(matrix, D, N), - make_device_matrix_view(out, D, N), + make_device_matrix_view(matrix, N, D), + make_device_matrix_view(out, N, D), along_lines, op, - make_device_vector_view(vec1, along_lines ? D : N), - make_device_vector_view(vec2, along_lines ? D : N)); + make_device_vector_view(vec1, bcastAlongRows ? N : D), + make_device_vector_view(vec2, bcastAlongRows ? N : D)); } else { matrix::linewise_op( handle, @@ -90,8 +90,8 @@ void matrixVectorOp(MatT* out, make_device_matrix_view(out, D, N), along_lines, op, - make_device_vector_view(vec1, along_lines ? D : N), - make_device_vector_view(vec2, along_lines ? D : N)); + make_device_vector_view(vec1, bcastAlongRows ? D : N), + make_device_vector_view(vec2, bcastAlongRows ? D : N)); } } From c31158ba8a74c706d9a8df98efb095382866c0f6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 9 May 2023 22:52:22 -0700 Subject: [PATCH 11/11] fix --- .../raft/linalg/detail/matrix_vector_op.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh index 6312e58b37..0c1261261c 100644 --- a/cpp/include/raft/linalg/detail/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/detail/matrix_vector_op.cuh @@ -47,11 +47,11 @@ void matrixVectorOp(MatT* out, } else { matrix::linewise_op( handle, - make_device_matrix_view(matrix, D, N), - make_device_matrix_view(out, D, N), + make_device_matrix_view(matrix, N, D), + make_device_matrix_view(out, N, D), along_lines, op, - make_device_vector_view(vec, bcastAlongRows ? D : N)); + make_device_vector_view(vec, bcastAlongRows ? N : D)); } } @@ -86,12 +86,12 @@ void matrixVectorOp(MatT* out, } else { matrix::linewise_op( handle, - make_device_matrix_view(matrix, D, N), - make_device_matrix_view(out, D, N), + make_device_matrix_view(matrix, N, D), + make_device_matrix_view(out, N, D), along_lines, op, - make_device_vector_view(vec1, bcastAlongRows ? D : N), - make_device_vector_view(vec2, bcastAlongRows ? D : N)); + make_device_vector_view(vec1, bcastAlongRows ? N : D), + make_device_vector_view(vec2, bcastAlongRows ? N : D)); } }