From d265e582522d2518686a510c4c54947410db7291 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Sun, 2 Oct 2022 04:53:19 -0400 Subject: [PATCH] Mdspanifying (currently tested) `raft::matrix` (#846) Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Mark Hoemmen (https://github.com/mhoemmen) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/846 --- BUILD.md | 26 ++- .../raft/cluster/detail/single_linkage.cuh | 2 +- cpp/include/raft/cluster/single_linkage.cuh | 47 ++++- .../raft/cluster/single_linkage_types.hpp | 30 ++- cpp/include/raft/core/mdspan.hpp | 3 +- cpp/include/raft/matrix/col_wise_sort.cuh | 107 ++++++++++- cpp/include/raft/matrix/copy.cuh | 97 ++++++++++ cpp/include/raft/matrix/detail/math.cuh | 18 +- cpp/include/raft/matrix/detail/matrix.cuh | 3 +- cpp/include/raft/matrix/detail/print.hpp | 48 +++++ cpp/include/raft/matrix/gather.cuh | 180 +++++++++++++++++- cpp/include/raft/matrix/init.cuh | 45 +++++ cpp/include/raft/matrix/linewise_op.cuh | 85 +++++++++ cpp/include/raft/matrix/math.cuh | 13 +- cpp/include/raft/matrix/matrix.cuh | 28 +++ cpp/include/raft/matrix/matrix_types.hpp | 26 +++ cpp/include/raft/matrix/power.cuh | 94 +++++++++ cpp/include/raft/matrix/print.cuh | 47 +++++ cpp/include/raft/matrix/print.hpp | 36 ++++ cpp/include/raft/matrix/ratio.cuh | 56 ++++++ cpp/include/raft/matrix/reciprocal.cuh | 81 ++++++++ cpp/include/raft/matrix/sign_flip.cuh | 39 ++++ cpp/include/raft/matrix/sqrt.cuh | 105 ++++++++++ cpp/include/raft/matrix/threshold.cuh | 61 ++++++ cpp/include/raft/spatial/knn/ball_cover.cuh | 8 +- .../raft/spatial/knn/ball_cover_types.hpp | 65 +++++-- cpp/include/raft/spatial/knn/brute_force.cuh | 6 +- .../raft/spatial/knn/detail/ball_cover.cuh | 13 +- .../knn/detail/ball_cover/registers.cuh | 6 +- .../knn/specializations/ball_cover.cuh | 2 +- .../detail/ball_cover_lowdim.hpp | 8 +- cpp/include/raft/util/input_validation.hpp | 87 +++++++++ cpp/src/nn/specializations/ball_cover.cu | 2 +- .../detail/ball_cover_lowdim_pass_one_2d.cu | 2 +- .../detail/ball_cover_lowdim_pass_one_3d.cu | 4 +- .../detail/ball_cover_lowdim_pass_two_2d.cu | 2 +- .../detail/ball_cover_lowdim_pass_two_3d.cu | 2 +- cpp/test/CMakeLists.txt | 3 +- cpp/test/matrix/columnSort.cu | 74 +++---- cpp/test/matrix/gather.cu | 28 ++- cpp/test/matrix/linewise_op.cu | 60 ++++-- cpp/test/matrix/math.cu | 46 ++++- cpp/test/matrix/matrix.cu | 48 +++-- cpp/test/sparse/linkage.cu | 26 +-- 44 files changed, 1559 insertions(+), 210 deletions(-) create mode 100644 cpp/include/raft/matrix/copy.cuh create mode 100644 cpp/include/raft/matrix/detail/print.hpp create mode 100644 cpp/include/raft/matrix/init.cuh create mode 100644 cpp/include/raft/matrix/linewise_op.cuh create mode 100644 cpp/include/raft/matrix/matrix_types.hpp create mode 100644 cpp/include/raft/matrix/power.cuh create mode 100644 cpp/include/raft/matrix/print.cuh create mode 100644 cpp/include/raft/matrix/print.hpp create mode 100644 cpp/include/raft/matrix/ratio.cuh create mode 100644 cpp/include/raft/matrix/reciprocal.cuh create mode 100644 cpp/include/raft/matrix/sign_flip.cuh create mode 100644 cpp/include/raft/matrix/sqrt.cuh create mode 100644 cpp/include/raft/matrix/threshold.cuh diff --git a/BUILD.md b/BUILD.md index 0c7fdd7e82..f572b11848 100644 --- a/BUILD.md +++ b/BUILD.md @@ -5,6 +5,7 @@ - [Build Dependencies](#required_depenencies) - [Header-only C++](#install_header_only_cpp) - [C++ Shared Libraries](#shared_cpp_libs) + - [Improving Rebuild 
Times](#ccache) - [Googletests](#gtests) - [C++ Using Cmake](#cpp_using_cmake) - [Python](#python) @@ -29,7 +30,6 @@ In addition to the libraries included with cudatoolkit 11.0+, there are some oth - [RMM](https://github.com/rapidsai/rmm) corresponding to RAFT version. #### Optional -- [mdspan](https://github.com/rapidsai/mdspan) - On by default but can be disabled. - [Thrust](https://github.com/NVIDIA/thrust) v1.15 / [CUB](https://github.com/NVIDIA/cub) - On by default but can be disabled. - [cuCollections](https://github.com/NVIDIA/cuCollections) - Used in `raft::sparse::distance` API. - [Libcu++](https://github.com/NVIDIA/libcudacxx) v1.7.0 @@ -53,11 +53,6 @@ The following example will download the needed dependencies and install the RAFT ./build.sh libraft --install ``` -The `--minimal-deps` flag can be used to install the headers with minimal dependencies: -```bash -./build.sh libraft --install --minimal-deps -``` - ### C++ Shared Libraries (optional) For larger projects which make heavy use of the pairwise distances or nearest neighbors APIs, shared libraries can be built to speed up compile times. These shared libraries can also significantly improve re-compile times both while developing RAFT and developing against the APIs. Build all of the available shared libraries by passing `--compile-libs` flag to `build.sh`: @@ -72,6 +67,14 @@ Individual shared libraries have their own flags and multiple can be used (thoug Add the `--install` flag to the above example to also install the shared libraries into `$INSTALL_PREFIX/lib`. +### `ccache` and `sccache` + +`ccache` and `sccache` can be used to better cache parts of the build when rebuilding frequently, such as when working on a new feature. You can also use `ccache` or `sccache` with `build.sh`: + +```bash +./build.sh libraft --cache-tool=ccache +``` + ### Tests Compile the tests using the `tests` target in `build.sh`. @@ -86,10 +89,17 @@ Test compile times can be improved significantly by using the optional shared li ./build.sh libraft tests --compile-libs ``` -To run C++ tests: +The tests are broken apart by algorithm category, so you will find several binaries in `cpp/build/` named `*_TEST`. + +For example, to run the distance tests: +```bash +./cpp/build/DISTANCE_TEST +``` + +It can take sometime to compile all of the tests. 
You can build individual tests by providing a semicolon-separated list to the `--limit-tests` option in `build.sh`: ```bash -./cpp/build/test_raft +./build.sh libraft tests --limit-tests=SPATIAL_TEST;DISTANCE_TEST;MATRIX_TEST ``` ### Benchmarks diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh index 7de942444e..9eee21b09c 100644 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ b/cpp/include/raft/cluster/detail/single_linkage.cuh @@ -54,7 +54,7 @@ void single_linkage(const raft::handle_t& handle, size_t m, size_t n, raft::distance::DistanceType metric, - linkage_output* out, + linkage_output* out, int c, size_t n_clusters) { diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 98735c74e4..8e33b8389d 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -17,9 +17,12 @@ #include #include +#include namespace raft::cluster { +constexpr int DEFAULT_CONST_C = 15; + /** * Single-linkage clustering, capable of constructing a KNN graph to * scale the algorithm beyond the n^2 memory consumption of implementations @@ -48,11 +51,53 @@ void single_linkage(const raft::handle_t& handle, size_t m, size_t n, raft::distance::DistanceType metric, - linkage_output* out, + linkage_output* out, int c, size_t n_clusters) { detail::single_linkage( handle, X, m, n, metric, out, c, n_clusters); } + +/** + * Single-linkage clustering, capable of constructing a KNN graph to + * scale the algorithm beyond the n^2 memory consumption of implementations + * that use the fully-connected graph of pairwise distances by connecting + * a knn graph when k is not large enough to connect it. + + * @tparam value_idx + * @tparam value_t + * @tparam dist_type method to use for constructing connectivities graph + * @param[in] handle raft handle + * @param[in] X dense input matrix in row-major layout + * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) + * @param[out] labels output labels vector (size n_rows) + * @param[in] metric distance metrix to use when constructing connectivities graph + * @param[in] n_clusters number of clusters to assign data samples + * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect + control of k. The algorithm will set `k = log(n) + c` + */ +template +void single_linkage(const raft::handle_t& handle, + raft::device_matrix_view X, + raft::device_matrix_view dendrogram, + raft::device_vector_view labels, + raft::distance::DistanceType metric, + size_t n_clusters, + std::optional c = std::make_optional(DEFAULT_CONST_C)) +{ + linkage_output out_arrs; + out_arrs.children = dendrogram.data_handle(); + out_arrs.labels = labels.data_handle(); + + single_linkage(handle, + X.data_handle(), + static_cast(X.extent(0)), + static_cast(X.extent(1)), + metric, + &out_arrs, + c.has_value() ? 
c.value() : DEFAULT_CONST_C, + n_clusters); +} + }; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 1c35cf5c68..d97e6afed3 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -16,6 +16,8 @@ #pragma once +#include + namespace raft::cluster { enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; @@ -27,23 +29,33 @@ enum LinkageDistance { PAIRWISE = 0, KNN_GRAPH = 1 }; * @tparam value_idx * @tparam value_t */ -template +template class linkage_output { public: - value_idx m; - value_idx n_clusters; + idx_t m; + idx_t n_clusters; + + idx_t n_leaves; + idx_t n_connected_components; - value_idx n_leaves; - value_idx n_connected_components; + // TODO: These will be made private in a future release + idx_t* labels; // size: m + idx_t* children; // size: (m-1, 2) - value_idx* labels; // size: m + raft::device_vector_view get_labels() + { + return raft::make_device_vector_view(labels, m); + } - value_idx* children; // size: (m-1, 2) + raft::device_matrix_view get_children() + { + return raft::make_device_matrix_view(children, m - 1, 2); + } }; -class linkage_output_int_float : public linkage_output { +class linkage_output_int_float : public linkage_output { }; -class linkage_output__int64_float : public linkage_output { +class linkage_output__int64_float : public linkage_output { }; }; // namespace raft::cluster \ No newline at end of file diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 6281ca98ea..a858633e07 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -255,5 +255,4 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx, return unravel_index_impl(static_cast(idx), shape); } } - -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index afdec24ebd..d26f5f73cf 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -18,10 +18,11 @@ #pragma once +#include +#include #include -namespace raft { -namespace matrix { +namespace raft::matrix { /** * @brief sort columns within each row of row-major input matrix and return sorted indexes @@ -50,7 +51,105 @@ void sort_cols_per_row(const InType* in, detail::sortColumnsPerRow( in, out, n_rows, n_columns, bAllocWorkspace, workspacePtr, workspaceSize, stream, sortedKeys); } -}; // end namespace matrix -}; // end namespace raft + +/** + * @brief sort columns within each row of row-major input matrix and return sorted indexes + * modelled as key-value sort with key being input matrix and value being index of values + * @tparam in_t: element type of input matrix + * @tparam out_t: element type of output matrix + * @tparam matrix_idx_t: integer type for matrix indexing + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output value(index) matrix + * @param[out] sorted_keys: Optional, output matrix for sorted keys (input) + */ +template +void sort_cols_per_row(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + std::optional> + sorted_keys = std::nullopt) +{ + RAFT_EXPECTS(in.extent(1) == out.extent(1) && in.extent(0) == out.extent(0), + "Input and output matrices must have the same shape."); + + if (sorted_keys.has_value()) { + RAFT_EXPECTS(in.extent(1) == sorted_keys.value().extent(1) 
&& + in.extent(0) == sorted_keys.value().extent(0), + "Input and `sorted_keys` matrices must have the same shape."); + } + + size_t workspace_size = 0; + bool alloc_workspace = false; + + in_t* keys = sorted_keys.has_value() ? sorted_keys.value().data_handle() : nullptr; + + detail::sortColumnsPerRow(in.data_handle(), + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + (void*)nullptr, + workspace_size, + handle.get_stream(), + keys); + + if (alloc_workspace) { + auto workspace = raft::make_device_vector(handle, workspace_size); + + detail::sortColumnsPerRow(in.data_handle(), + out.data_handle(), + in.extent(0), + in.extent(1), + alloc_workspace, + (void*)workspace.data_handle(), + workspace_size, + handle.get_stream(), + keys); + } +} + +namespace sort_cols_per_row_impl { +template +struct sorted_keys_alias { +}; + +template <> +struct sorted_keys_alias { + using type = double; +}; + +template +struct sorted_keys_alias< + std::optional>> { + using type = typename raft::device_matrix_view::value_type; +}; + +template +using sorted_keys_t = typename sorted_keys_alias::type; +} // namespace sort_cols_per_row_impl + +/** + * @brief Overload of `sort_keys_per_row` to help the + * compiler find the above overload, in case users pass in + * `std::nullopt` for one or both of the optional arguments. + * + * Please see above for documentation of `sort_keys_per_row`. + */ +template +void sort_cols_per_row(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + sorted_keys_vector_type sorted_keys) +{ + using sorted_keys_type = sort_cols_per_row_impl::sorted_keys_t< + std::remove_const_t>>; + std::optional> sorted_keys_opt = + std::forward(sorted_keys); + + sort_cols_per_row(handle, in, out, sorted_keys_opt); +} + +}; // end namespace raft::matrix #endif \ No newline at end of file diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh new file mode 100644 index 0000000000..5f1d16485c --- /dev/null +++ b/cpp/include/raft/matrix/copy.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Copy selected rows of the input matrix into contiguous space. + * + * On exit out[i + k*n_rows] = in[indices[i] + k*n_rows], + * where i = 0..n_rows_indices-1, and k = 0..n_cols-1. 
+ * + * @param[in] handle raft handle + * @param[in] in input matrix + * @param[out] out output matrix + * @param[in] indices of the rows to be copied + */ +template +void copy_rows(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view indices) +{ + RAFT_EXPECTS(in.extent(1) == out.extent(1), + "Input and output matrices must have same number of columns"); + RAFT_EXPECTS(indices.extent(0) == out.extent(0), + "Number of rows in output matrix must equal number of indices"); + detail::copyRows(in.data_handle(), + in.extent(0), + in.extent(1), + out.data_handle(), + indices.data_handle(), + indices.extent(0), + handle.get_stream(), + raft::is_row_major(in)); +} + +/** + * @brief copy matrix operation for column major matrices. + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix + */ +template +void copy(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), + "Input and output matrix shapes must match."); + + raft::copy_async( + out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); +} + +/** + * @brief copy matrix operation for column major matrices. First n_rows and + * n_cols of input matrix "in" is copied to "out" matrix. + * @param handle: raft handle for managing resources + * @param in: input matrix + * @param out: output matrix + */ +template +void trunc_zero_origin(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(out.extent(0) <= in.extent(0) && out.extent(1) <= in.extent(1), + "Output matrix must have less or equal number of rows and columns"); + + detail::truncZeroOrigin(in.data_handle(), + in.extent(0), + out.data_handle(), + out.extent(0), + out.extent(1), + handle.get_stream()); +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index 95953feca4..07b9ccc12b 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -141,7 +141,7 @@ void setSmallValuesZero(math_t* inout, IdxType len, cudaStream_t stream, math_t } template -void reciprocal(math_t* in, +void reciprocal(const math_t* in, math_t* out, math_t scalar, int len, @@ -363,8 +363,8 @@ void matrixVectorBinarySub(Type* data, } // Computes the argmax(d_in) column-wise in a DxN matrix -template -__global__ void argmaxKernel(const T* d_in, int D, int N, T* argmax) +template +__global__ void argmaxKernel(const T* d_in, int D, int N, IdxT* argmax) { typedef cub::BlockReduce, TPB> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -384,19 +384,19 @@ __global__ void argmaxKernel(const T* d_in, int D, int N, T* argmax) if (threadIdx.x == 0) { argmax[blockIdx.x] = maxKV.key; } } -template -void argmax(const math_t* in, int n_rows, int n_cols, math_t* out, cudaStream_t stream) +template +void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream) { int D = n_rows; int N = n_cols; if (D <= 32) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else if (D <= 64) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else if (D <= 128) { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } else { - argmaxKernel<<>>(in, D, N, out); + argmaxKernel<<>>(in, D, N, out); } 
RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index a8568b0859..c425aad79b 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -67,7 +67,7 @@ void copyRows(const m_t* in, template void truncZeroOrigin( - m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) + const m_t* in, idx_t in_n_rows, m_t* out, idx_t out_n_rows, idx_t out_n_cols, cudaStream_t stream) { auto m = out_n_rows; auto k = in_n_rows; @@ -279,7 +279,6 @@ m_t getL2Norm(const raft::handle_t& handle, m_t* in, idx_t size, cudaStream_t st { cublasHandle_t cublasH = handle.get_cublas_handle(); m_t normval = 0; - // #TODO: Call from the public API when ready RAFT_CUBLAS_TRY(raft::linalg::detail::cublasnrm2(cublasH, size, in, 1, &normval, stream)); return normval; } diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp new file mode 100644 index 0000000000..fc3d14861c --- /dev/null +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace raft::matrix::detail { + +template +void printHost( + const m_t* in, idx_t n_rows, idx_t n_cols, char h_separator = ' ', char v_separator = '\n', ) +{ + for (idx_t i = 0; i < n_rows; i++) { + for (idx_t j = 0; j < n_cols; j++) { + printf("%1.4f%c", in[j * n_rows + i], j < n_cols - 1 ? h_separator : v_separator); + } + } +} + +} // end namespace raft::matrix::detail diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 31164b2041..fa6e73de49 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -15,10 +15,12 @@ */ #pragma once + +#include +#include #include -namespace raft { -namespace matrix { +namespace raft::matrix { /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. @@ -49,6 +51,76 @@ void gather(const MatrixIteratorT in, detail::gather(in, D, N, map, map_length, out, stream); } +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a map. 
+ * + * @tparam matrix_t Matrix element type + * @tparam map_t Map vector type + * @tparam idx_t integer type used for indexing + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Vector of gather locations + * @param[out] out Output matrix (assumed to be row-major) + */ +template +void gather(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view map, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); + + raft::matrix::detail::gather( + const_cast(in.data_handle()), // TODO: There's a better way to handle this + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map.data_handle(), + static_cast(map.extent(0)), + out.data_handle(), + handle.get_stream()); +} + +/** + * @brief gather copies rows from a source matrix into a destination matrix according to a + * transformed map. + * + * @tparam matrix_t Matrix type + * @tparam map_t Map vector type + * @tparam map_xform_t Unary lambda expression or operator type, MapTransformOp's result + * type must be convertible to idx_t (= int) type. + * @tparam idx_t integer type for indexing + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Input vector of gather locations + * @param[out] out Output matrix (assumed to be row-major) + * @param[in] transform_op The transformation operation, transforms the map values to idx_t + */ +template +void gather(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_vector_view map, + raft::device_matrix_view out, + map_xform_t transform_op) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); + + detail::gather( + const_cast(in.data_handle()), // TODO: There's a better way to handle this + static_cast(in.extent(1)), + static_cast(in.extent(0)), + map, + static_cast(map.extent(0)), + out.data_handle(), + transform_op, + handle.get_stream()); +} + /** * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. @@ -124,6 +196,53 @@ void gather_if(const MatrixIteratorT in, detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, stream); } +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a map. + * + * @tparam matrix_t Matrix value type + * @tparam map_t Map vector type + * @tparam stencil_t Stencil vector type + * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result + * type must be convertible to bool type. 
+ * @tparam idx_t integer type for indexing + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Input vector of gather locations + * @param[in] stencil Input vector of stencil or predicate values + * @param[out] out Output matrix (assumed to be row-major) + * @param[in] pred_op Predicate to apply to the stencil values + */ +template +void gather_if(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map, + raft::device_vector_view stencil, + unary_pred_t pred_op) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); + RAFT_EXPECTS(map.extent(0) == stencil.extent(0), + "Number of elements in stencil must equal number of elements in map"); + + detail::gather_if(const_cast(in.data_handle()), + out.extent(1), + out.extent(0), + map.data_handle(), + stencil.data_handle(), + map.extent(0), + out.data_handle(), + pred_op, + handle.get_stream()); +} + /** * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. @@ -169,5 +288,58 @@ void gather_if(const MatrixIteratorT in, { detail::gather_if(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } -} // namespace matrix -} // namespace raft + +/** + * @brief gather_if conditionally copies rows from a source matrix into a destination matrix + * according to a transformed map. + * + * @tparam matrix_t Matrix value type, for reading input matrix + * @tparam map_t Vector value type for map + * @tparam stencil_t Vector value type for stencil + * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result + * type must be convertible to bool type. + * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result + * type must be convertible to idx_t (= int) type. 
+ * @tparam idx_t integer type for indexing + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix (assumed to be row-major) + * @param[in] map Vector of gather locations + * @param[in] stencil Vector of stencil or predicate values + * @param[out] out Output matrix (assumed to be row-major) + * @param[in] pred_op Predicate to apply to the stencil values + * @param[in] transform_op The transformation operation, transforms the map values to idx_t + */ +template +void gather_if(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::device_vector_view map, + raft::device_vector_view stencil, + unary_pred_t pred_op, + map_xform_t transform_op) +{ + RAFT_EXPECTS(out.extent(0) == map.extent(0), + "Number of rows in output matrix must equal the size of the map vector"); + RAFT_EXPECTS(out.extent(1) == in.extent(1), + "Number of columns in input and output matrices must be equal."); + RAFT_EXPECTS(map.extent(0) == stencil.extent(0), + "Number of elements in stencil must equal number of elements in map"); + + detail::gather_if(const_cast(in.data_handle()), + in.extent(1), + in.extent(0), + map.data_handle(), + stencil.data_handle(), + map.extent(0), + out.data_handle(), + pred_op, + transform_op, + handle.get_stream()); +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/init.cuh b/cpp/include/raft/matrix/init.cuh new file mode 100644 index 0000000000..e3a6c09fe6 --- /dev/null +++ b/cpp/include/raft/matrix/init.cuh @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::matrix { +/** + * @brief set values to scalar in matrix + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[in] in input matrix + * @param[out] out output matrix. The result is stored in the out matrix + * @param[in] scalar scalar value to fill matrix elements + */ +template +void fill(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::host_scalar_view scalar) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be the same size."); + detail::setValue( + out.data_handle(), in.data_handle(), *(scalar.data_handle()), in.size(), handle.get_stream()); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/linewise_op.cuh b/cpp/include/raft/matrix/linewise_op.cuh new file mode 100644 index 0000000000..6b383b14f5 --- /dev/null +++ b/cpp/include/raft/matrix/linewise_op.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * Run a function over matrix lines (rows or columns) with a variable number + * row-vectors or column-vectors. + * The term `line` here signifies that the lines can be either columns or rows, + * depending on the matrix layout. + * What matters is if the vectors are applied along lines (indices of vectors correspond to + * indices within lines), or across lines (indices of vectors correspond to line numbers). + * @tparam m_t matrix elements type + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @tparam Lambda type of lambda function used for the operation + * @tparam vec_t variadic types of device_vector_view vectors (size m if alongRows, size n + * otherwise) + * @param[in] handle raft handle for managing resources + * @param [out] out result of the operation; can be same as `in`; should be aligned the same + * as `in` to allow faster vectorized memory transfers. + * @param [in] in input matrix consisting of `nLines` lines, each `lineLen`-long. + * @param [in] alongLines whether vectors are indices along or across lines. + * @param [in] op the operation applied on each line: + * for i in [0..lineLen) and j in [0..nLines): + * out[i, j] = op(in[i, j], vec1[i], vec2[i], ... veck[i]) if alongLines = true + * out[i, j] = op(in[i, j], vec1[j], vec2[j], ... veck[j]) if alongLines = false + * where matrix indexing is row-major ([i, j] = [i + lineLen * j]). + * @param [in] vecs zero or more vectors to be passed as arguments, + * size of each vector is `alongLines ? lineLen : nLines`. + */ +template > +void linewise_op(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + const bool alongLines, + Lambda op, + vec_t... vecs) +{ + constexpr auto is_rowmajor = std::is_same_v; + constexpr auto is_colmajor = std::is_same_v; + + static_assert(is_rowmajor || is_colmajor, + "layout for in and out must be either row or col major"); + + const idx_t lineLen = is_rowmajor ? in.extent(0) : in.extent(1); + const idx_t nLines = is_rowmajor ? in.extent(1) : in.extent(0); + + RAFT_EXPECTS(out.extent(0) == in.extent(0) && out.extent(1) == in.extent(1), + "Input and output must have the same shape."); + + detail::MatrixLinewiseOp<16, 256>::run(out.data_handle(), + in.data_handle(), + lineLen, + nLines, + alongLines, + op, + handle.get_stream(), + vecs.data_handle()...); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/math.cuh b/cpp/include/raft/matrix/math.cuh index 9e103afda5..3c2705cf87 100644 --- a/cpp/include/raft/matrix/math.cuh +++ b/cpp/include/raft/matrix/math.cuh @@ -14,6 +14,15 @@ * limitations under the License. */ +/** + * This file is deprecated and will be removed in a future release. + * Please use versions in individual header files instead. + */ + +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." 
\ + " Please use versions in individual header files instead.") + #ifndef __MATH_H #define __MATH_H @@ -301,8 +310,8 @@ void ratio( * @param out: output vector of size n_cols * @param stream: cuda stream */ -template -void argmax(const math_t* in, int n_rows, int n_cols, math_t* out, cudaStream_t stream) +template +void argmax(const math_t* in, int n_rows, int n_cols, idx_t* out, cudaStream_t stream) { detail::argmax(in, n_rows, n_cols, out, stream); } diff --git a/cpp/include/raft/matrix/matrix.cuh b/cpp/include/raft/matrix/matrix.cuh index 1af7e37dec..3a7e0dad47 100644 --- a/cpp/include/raft/matrix/matrix.cuh +++ b/cpp/include/raft/matrix/matrix.cuh @@ -14,6 +14,15 @@ * limitations under the License. */ +/** + * This file is deprecated and will be removed in a future release. + * Please use versions in individual header files instead. + */ + +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use versions in individual header files instead.") + #ifndef __MATRIX_H #define __MATRIX_H @@ -21,6 +30,7 @@ #include "detail/linewise_op.cuh" #include "detail/matrix.cuh" +#include #include @@ -71,6 +81,24 @@ void copy(const m_t* in, m_t* out, idx_t n_rows, idx_t n_cols, cudaStream_t stre raft::copy_async(out, in, n_rows * n_cols, stream); } +/** + * @brief copy matrix operation for column major matrices. + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix + */ +template +void copy(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.extent(0) == out.extent(0) && in.extent(1) == out.extent(1), + "Input and output matrix shapes must match."); + + raft::copy_async( + out.data_handle(), in.data_handle(), in.extent(0) * out.extent(1), handle.get_stream()); +} + /** * @brief copy matrix operation for column major matrices. First n_rows and * n_cols of input matrix "in" is copied to "out" matrix. diff --git a/cpp/include/raft/matrix/matrix_types.hpp b/cpp/include/raft/matrix/matrix_types.hpp new file mode 100644 index 0000000000..1f22154627 --- /dev/null +++ b/cpp/include/raft/matrix/matrix_types.hpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::matrix { + +struct print_separators { + char horizontal = ' '; + char vertical = '\n'; +}; + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/power.cuh b/cpp/include/raft/matrix/power.cuh new file mode 100644 index 0000000000..4e2b3b7d72 --- /dev/null +++ b/cpp/include/raft/matrix/power.cuh @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::matrix { + +/** + * @brief Power of every element in the input matrix + * @tparam math_t type of matrix elements + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix. The result is stored in the out matrix + * @param[in] scalar: every element is multiplied with scalar. + */ +template +void weighted_power(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t scalar) +{ + RAFT_EXPECTS(in.size() == out.size(), "Size of input and output matrices must be equal"); + detail::power(in.data_handle(), out.data_handle(), scalar, in.size(), handle.get_stream()); +} + +/** + * @brief Power of every element in the input matrix (inplace) + * @tparam math_t matrix element type + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[inout] inout: input matrix and also the result is stored + * @param[in] scalar: every element is multiplied with scalar. + */ +template +void weighted_power(const raft::handle_t& handle, + raft::device_matrix_view inout, + math_t scalar) +{ + detail::power(inout.data_handle(), scalar, inout.size(), handle.get_stream()); +} + +/** + * @brief Power of every element in the input matrix (inplace) + * @tparam math_t matrix element type + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[inout] inout: input matrix and also the result is stored + */ +template +void power(const raft::handle_t& handle, raft::device_matrix_view inout) +{ + detail::power(inout.data_handle(), inout.size(), handle.get_stream()); +} + +/** + * @brief Power of every element in the input matrix + * @tparam math_t type used for matrix elements + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix (row or column major) + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix. The result is stored in the out matrix + * @{ + */ +template +void power(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must be same size."); + detail::power(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.cuh b/cpp/include/raft/matrix/print.cuh new file mode 100644 index 0000000000..4d3a8ca938 --- /dev/null +++ b/cpp/include/raft/matrix/print.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Prints the data stored in GPU memory + * @tparam m_t type of matrix elements + * @tparam idx_t integer type used for indexing + * @param[in] handle: raft handle + * @param[in] in: input matrix + * @param[in] separators: horizontal and vertical separator characters + */ +template +void print(const raft::handle_t& handle, + raft::device_matrix_view in, + print_separators& separators) +{ + detail::print(in.data_handle(), + in.extent(0), + in.extent(1), + separators.horizontal, + separators.vertical, + handle.get_stream()); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/print.hpp b/cpp/include/raft/matrix/print.hpp new file mode 100644 index 0000000000..86c314ed44 --- /dev/null +++ b/cpp/include/raft/matrix/print.hpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Prints the data stored in CPU memory + * @param[in] in: input matrix with column-major layout + * @param[in] separators: horizontal and vertical separator characters + */ +template +void print(raft::host_matrix_view in, print_separators& separators) +{ + detail::printHost( + in.data_handle(), in.extent(0), in.extent(1), separators.horizontal, separators.vertical); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/ratio.cuh b/cpp/include/raft/matrix/ratio.cuh new file mode 100644 index 0000000000..7895ea972f --- /dev/null +++ b/cpp/include/raft/matrix/ratio.cuh @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +namespace raft::matrix { + +/** + * @brief ratio of every element over sum of input vector is calculated + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle + * @param[in] src: input matrix + * @param[out] dest: output matrix. The result is stored in the dest matrix + */ +template +void ratio(const raft::handle_t& handle, + raft::device_matrix_view src, + raft::device_matrix_view dest) +{ + RAFT_EXPECTS(src.size() == dest.size(), "Input and output matrices must be the same size."); + detail::ratio(handle, src.data_handle(), dest.data_handle(), src.size(), handle.get_stream()); +} + +/** + * @brief ratio of every element over sum of input vector is calculated + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle + * @param[inout] inout: input matrix + */ +template +void ratio(const raft::handle_t& handle, raft::device_matrix_view inout) +{ + detail::ratio( + handle, inout.data_handle(), inout.data_handle(), inout.size(), handle.get_stream()); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/reciprocal.cuh b/cpp/include/raft/matrix/reciprocal.cuh new file mode 100644 index 0000000000..c41ecfb999 --- /dev/null +++ b/cpp/include/raft/matrix/reciprocal.cuh @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Reciprocal of every element in the input matrix + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param handle: raft handle + * @param in: input matrix and also the result is stored + * @param out: output matrix. 
The result is stored in the out matrix + * @param scalar: every element is multiplied with scalar + * @param setzero round down to zero if the input is less the threshold + * @param thres the threshold used to forcibly set inputs to zero + * @{ + */ +template +void reciprocal(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::host_scalar_view scalar, + bool setzero = false, + math_t thres = 1e-15) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have the same size."); + detail::reciprocal(in.data_handle(), + out.data_handle(), + *(scalar.data_handle()), + in.size(), + handle.get_stream(), + setzero, + thres); +} + +/** + * @brief Reciprocal of every element in the input matrix (in place) + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle to manage resources + * @param[inout] inout: input matrix with in-place results + * @param[in] scalar: every element is multiplied with scalar + * @param[in] setzero round down to zero if the input is less the threshold + * @param[in] thres the threshold used to forcibly set inputs to zero + * @{ + */ +template +void reciprocal(const raft::handle_t& handle, + raft::device_matrix_view inout, + raft::host_scalar_view scalar, + bool setzero = false, + math_t thres = 1e-15) +{ + detail::reciprocal(inout.data_handle(), + *(scalar.data_handle()), + inout.size(), + handle.get_stream(), + setzero, + thres); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/sign_flip.cuh b/cpp/include/raft/matrix/sign_flip.cuh new file mode 100644 index 0000000000..01f8829c85 --- /dev/null +++ b/cpp/include/raft/matrix/sign_flip.cuh @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief sign flip stabilizes the sign of col major eigen vectors. + * The sign is flipped if the column has negative |max|. + * @tparam math_t floating point type used for matrix elements + * @tparam idx_t integer type used for indexing + * @param[in] handle: raft handle + * @param[inout] inout: input matrix. Result also stored in this parameter + */ +template +void sign_flip(const raft::handle_t& handle, + raft::device_matrix_view inout) +{ + detail::signFlip(inout.data_handle(), inout.extent(0), inout.extent(1), handle.get_stream()); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/sqrt.cuh b/cpp/include/raft/matrix/sqrt.cuh new file mode 100644 index 0000000000..302167480e --- /dev/null +++ b/cpp/include/raft/matrix/sqrt.cuh @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::matrix { + +/** + * @brief Square root of every element in the input matrix + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[in] in: input matrix and also the result is stored + * @param[out] out: output matrix. The result is stored in the out matrix + */ +template +void sqrt(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); + detail::seqRoot(in.data_handle(), out.data_handle(), in.size(), handle.get_stream()); +} + +/** + * @brief Square root of every element in the input matrix (in place) + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[inout] inout: input matrix with in-place results + */ +template +void sqrt(const raft::handle_t& handle, raft::device_matrix_view inout) +{ + detail::seqRoot(inout.data_handle(), inout.size(), handle.get_stream()); +} + +/** + * @brief Square root of every element in the input matrix + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[in] in: input matrix and also the result is stored + * @param[out] out: output matrix. 
The result is stored in the out matrix + * @param[in] scalar: every element is multiplied with scalar + * @param[in] set_neg_zero whether to set negative numbers to zero + */ +template +void weighted_sqrt(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + raft::host_scalar_view scalar, + bool set_neg_zero = false) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size."); + detail::seqRoot(in.data_handle(), + out.data_handle(), + *(scalar.data_handle()), + in.size(), + handle.get_stream(), + set_neg_zero); +} + +/** + * @brief Square root of every element in the input matrix (in place) + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param[in] handle: raft handle + * @param[inout] inout: input matrix and also the result is stored + * @param[in] scalar: every element is multiplied with scalar + * @param[in] set_neg_zero whether to set negative numbers to zero + */ +template +void weighted_sqrt(const raft::handle_t& handle, + raft::device_matrix_view inout, + raft::host_scalar_view scalar, + bool set_neg_zero = false) +{ + detail::seqRoot( + inout.data_handle(), *(scalar.data_handle()), inout.size(), handle.get_stream(), set_neg_zero); +} + +} // namespace raft::matrix diff --git a/cpp/include/raft/matrix/threshold.cuh b/cpp/include/raft/matrix/threshold.cuh new file mode 100644 index 0000000000..7540ceb3c6 --- /dev/null +++ b/cpp/include/raft/matrix/threshold.cuh @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::matrix { + +/** + * @brief sets the small values to zero based on a defined threshold + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param handle: raft handle + * @param[in] in: input matrix + * @param[out] out: output matrix. 
The result is stored in the out matrix + * @param[in] thres threshold to set values to zero + */ +template +void zero_small_values(const raft::handle_t& handle, + raft::device_matrix_view in, + raft::device_matrix_view out, + math_t thres = 1e-15) +{ + RAFT_EXPECTS(in.size() == out.size(), "Input and output matrices must have same size"); + detail::setSmallValuesZero( + out.data_handle(), in.data_handle(), in.size(), handle.get_stream(), thres); +} + +/** + * @brief sets the small values to zero in-place based on a defined threshold + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @tparam layout layout of the matrix data (must be row or col major) + * @param handle: raft handle + * @param inout: input matrix and also the result is stored + * @param thres: threshold + */ +template +void zero_small_values(const raft::handle_t& handle, + raft::device_matrix_view inout, + math_t thres = 1e-15) +{ + detail::setSmallValuesZero(inout.data_handle(), inout.size(), handle.get_stream(), thres); +} +} // namespace raft::matrix diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index 714b019fba..9cb9b573b1 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -203,7 +203,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, */ template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, int_t k, const value_t* query, int_t n_query_pts, @@ -272,7 +272,7 @@ void rbc_knn_query(const raft::handle_t& handle, */ template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, raft::device_matrix_view query, raft::device_matrix_view inds, raft::device_matrix_view dists, @@ -289,7 +289,7 @@ void rbc_knn_query(const raft::handle_t& handle, "Number of rows in output indices and distances matrices must equal number of rows " "in search matrix."); - RAFT_EXPECTS(query.extent(1) == index.get_R().extent(1), + RAFT_EXPECTS(query.extent(1) == index.get_X().extent(1), "Number of columns in query and index matrices must match."); rbc_knn_query(handle, @@ -311,4 +311,4 @@ void rbc_knn_query(const raft::handle_t& handle, } // namespace spatial } // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/spatial/knn/ball_cover_types.hpp b/cpp/include/raft/spatial/knn/ball_cover_types.hpp index 1dd45365b7..897bb4df5b 100644 --- a/cpp/include/raft/spatial/knn/ball_cover_types.hpp +++ b/cpp/include/raft/spatial/knn/ball_cover_types.hpp @@ -58,13 +58,12 @@ class BallCoverIndex { * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ n_landmarks(sqrt(m_)), - R_indptr(std::move(raft::make_device_vector(handle, sqrt(m_) + 1))), - R_1nn_cols(std::move(raft::make_device_vector(handle, m_))), - R_1nn_dists(std::move(raft::make_device_vector(handle, m_))), - R_closest_landmark_dists( - std::move(raft::make_device_vector(handle, m_))), - R(std::move(raft::make_device_matrix(handle, sqrt(m_), n_))), - R_radius(std::move(raft::make_device_vector(handle, sqrt(m_)))), + R_indptr(raft::make_device_vector(handle, sqrt(m_) + 1)), + R_1nn_cols(raft::make_device_vector(handle, m_)), + R_1nn_dists(raft::make_device_vector(handle, m_)), + R_closest_landmark_dists(raft::make_device_vector(handle, m_)), + R(raft::make_device_matrix(handle, sqrt(m_), n_)), + R_radius(raft::make_device_vector(handle, 
sqrt(m_))), index_trained(false) { } @@ -83,20 +82,41 @@ class BallCoverIndex { * Total memory footprint of index: (2 * sqrt(m)) + (n * sqrt(m)) + (2 * m) */ n_landmarks(sqrt(X_.extent(0))), - R_indptr( - std::move(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1))), - R_1nn_cols(std::move(raft::make_device_vector(handle, X_.extent(0)))), - R_1nn_dists(std::move(raft::make_device_vector(handle, X_.extent(0)))), - R_closest_landmark_dists( - std::move(raft::make_device_vector(handle, X_.extent(0)))), - R(std::move( - raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1)))), - R_radius( - std::move(raft::make_device_vector(handle, sqrt(X_.extent(0))))), + R_indptr(raft::make_device_vector(handle, sqrt(X_.extent(0)) + 1)), + R_1nn_cols(raft::make_device_vector(handle, X_.extent(0))), + R_1nn_dists(raft::make_device_vector(handle, X_.extent(0))), + R_closest_landmark_dists(raft::make_device_vector(handle, X_.extent(0))), + R(raft::make_device_matrix(handle, sqrt(X_.extent(0)), X_.extent(1))), + R_radius(raft::make_device_vector(handle, sqrt(X_.extent(0)))), index_trained(false) { } + auto get_R_indptr() const -> raft::device_vector_view + { + return R_indptr.view(); + } + auto get_R_1nn_cols() const -> raft::device_vector_view + { + return R_1nn_cols.view(); + } + auto get_R_1nn_dists() const -> raft::device_vector_view + { + return R_1nn_dists.view(); + } + auto get_R_radius() const -> raft::device_vector_view + { + return R_radius.view(); + } + auto get_R() const -> raft::device_matrix_view + { + return R.view(); + } + auto get_R_closest_landmark_dists() const -> raft::device_vector_view + { + return R_closest_landmark_dists.view(); + } + raft::device_vector_view get_R_indptr() { return R_indptr.view(); } raft::device_vector_view get_R_1nn_cols() { return R_1nn_cols.view(); } raft::device_vector_view get_R_1nn_dists() { return R_1nn_dists.view(); } @@ -106,8 +126,11 @@ class BallCoverIndex { { return R_closest_landmark_dists.view(); } - raft::device_matrix_view get_X() { return X; } + raft::device_matrix_view get_X() const { return X; } + + raft::distance::DistanceType get_metric() const { return metric; } + value_int get_n_landmarks() const { return n_landmarks; } bool is_index_trained() const { return index_trained; }; // This should only be set by internal functions @@ -115,9 +138,9 @@ class BallCoverIndex { const raft::handle_t& handle; - const value_int m; - const value_int n; - const value_int n_landmarks; + value_int m; + value_int n; + value_int n_landmarks; raft::device_matrix_view X; diff --git a/cpp/include/raft/spatial/knn/brute_force.cuh b/cpp/include/raft/spatial/knn/brute_force.cuh index c32a33d2e2..dda1e02eed 100644 --- a/cpp/include/raft/spatial/knn/brute_force.cuh +++ b/cpp/include/raft/spatial/knn/brute_force.cuh @@ -42,7 +42,6 @@ namespace raft::spatial::knn { * @param[out] out_keys matrix of output keys (size n_samples * k) * @param[out] out_values matrix of output values (size n_samples * k) * @param[in] n_samples number of rows in each part - * @param[in] k number of neighbors for each part * @param[in] translations optional vector of starting index mappings for each partition */ template @@ -53,7 +52,6 @@ inline void knn_merge_parts( raft::device_matrix_view out_keys, raft::device_matrix_view out_values, size_t n_samples, - int k, std::optional> translations = std::nullopt) { RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), @@ -61,7 +59,7 @@ inline void knn_merge_parts( RAFT_EXPECTS( 
out_keys.extent(0) == out_values.extent(0) == n_samples, "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == k, + RAFT_EXPECTS(out_keys.extent(1) == out_values.extent(1) == in_keys.extent(1), "Number of columns in output indices and distances matrices must be equal to k"); auto n_parts = in_keys.extent(0) / n_samples; @@ -71,7 +69,7 @@ inline void knn_merge_parts( out_values.data_handle(), n_samples, n_parts, - k, + in_keys.extent(1), handle.get_stream(), translations.value_or(nullptr)); } diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index e65a895f60..94897daa22 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -174,15 +174,16 @@ void construct_landmark_1nn(const raft::handle_t& handle, */ template void k_closest_landmarks(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query_pts, value_int n_query_pts, value_int k, value_idx* R_knn_inds, value_t* R_knn_dists) { - std::vector input = {index.get_R().data_handle()}; - std::vector sizes = {index.n_landmarks}; + // TODO: Add const to the brute-force knn inputs + std::vector input = {const_cast(index.get_R().data_handle())}; + std::vector sizes = {index.n_landmarks}; brute_force_knn_impl(handle, input, @@ -196,7 +197,7 @@ void k_closest_landmarks(const raft::handle_t& handle, true, true, nullptr, - index.metric); + index.get_metric()); } /** @@ -240,7 +241,7 @@ template void perform_rbc_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, value_int n_query_pts, std::uint32_t k, @@ -470,7 +471,7 @@ template void rbc_knn_query(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, value_int k, const value_t* query, value_int n_query_pts, diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index c0056e7137..112ab9f13c 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -331,7 +331,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, value_idx* out_inds, value_t* out_dists, value_int* dist_counter, - value_t* R_radius, + const value_t* R_radius, distance_func dfunc, float weight = 1.0) { @@ -472,7 +472,7 @@ template void rbc_low_dim_pass_one(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, @@ -604,7 +604,7 @@ template void rbc_low_dim_pass_two(const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const value_t* query, const value_int n_query_rows, value_int k, diff --git a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh index c859f2c5ec..a861375b2f 100644 --- a/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/specializations/ball_cover.cuh @@ -34,7 +34,7 @@ extern template void rbc_build_index( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, diff --git 
a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp index afee3bd7a3..31df566b3f 100644 --- a/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp +++ b/cpp/include/raft/spatial/knn/specializations/detail/ball_cover_lowdim.hpp @@ -25,7 +25,7 @@ namespace detail { extern template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -39,7 +39,7 @@ extern template void rbc_low_dim_pass_one extern template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -53,7 +53,7 @@ extern template void rbc_low_dim_pass_two extern template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -67,7 +67,7 @@ extern template void rbc_low_dim_pass_one extern template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/include/raft/util/input_validation.hpp b/cpp/include/raft/util/input_validation.hpp index b34843f5e8..ab5264f900 100644 --- a/cpp/include/raft/util/input_validation.hpp +++ b/cpp/include/raft/util/input_validation.hpp @@ -42,4 +42,91 @@ constexpr bool is_row_or_column_major(mdspan +constexpr bool is_row_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_row_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_row_major(mdspan /* m */) +{ + return true; +} + +template +constexpr bool is_row_major(mdspan m) +{ + return m.is_exhaustive() && m.stride(1) == typename Extents::index_type(1); +} + +template +constexpr bool is_col_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_col_major(mdspan /* m */) +{ + return true; +} + +template +constexpr bool is_col_major(mdspan /* m */) +{ + return false; +} + +template +constexpr bool is_col_major(mdspan m) +{ + return m.is_exhaustive() && m.stride(0) == typename Extents::index_type(1); +} + +template +constexpr bool is_matrix_view( + mdspan, Layout, Accessor> /* m */) +{ + return sizeof...(Exts) == 2; +} + +template +constexpr bool is_matrix_view(mdspan m) +{ + return false; +} + +template +constexpr bool is_vector_view( + mdspan, Layout, Accessor> /* m */) +{ + return sizeof...(Exts) == 1; +} + +template +constexpr bool is_vector_view(mdspan m) +{ + return false; +} + +template +constexpr bool is_scalar_view( + mdspan, Layout, Accessor> /* m */) +{ + return sizeof...(Exts) == 0; +} + +template +constexpr bool is_scalar_view(mdspan m) +{ + return false; +} + }; // end namespace raft \ No newline at end of file diff --git a/cpp/src/nn/specializations/ball_cover.cu b/cpp/src/nn/specializations/ball_cover.cu index 7473b65d25..15af9f6e68 100644 --- a/cpp/src/nn/specializations/ball_cover.cu +++ b/cpp/src/nn/specializations/ball_cover.cu @@ -37,7 +37,7 @@ template void rbc_build_index template void rbc_knn_query( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, std::uint32_t k, const float* query, std::uint32_t n_query_pts, diff --git 
a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu index 8950ff8d5c..d2d729a52d 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_2d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu index 7b8b6ce9a2..0b32d43ba9 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_one( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, @@ -39,7 +39,7 @@ template void rbc_low_dim_pass_one( template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu index 29e8eec8c8..7c8f18859f 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_2d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu index d6d4b356c8..1ef071033c 100644 --- a/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ b/cpp/src/nn/specializations/detail/ball_cover_lowdim_pass_two_3d.cu @@ -25,7 +25,7 @@ namespace detail { template void rbc_low_dim_pass_two( const raft::handle_t& handle, - BallCoverIndex& index, + const BallCoverIndex& index, const float* query, const std::uint32_t n_query_rows, std::uint32_t k, diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index d63c907dcc..0c9b721294 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -82,6 +82,8 @@ if(BUILD_TESTS) PATH test/cluster/kmeans.cu test/cluster_solvers.cu + test/sparse/linkage.cu + OPTIONAL DIST NN ) ConfigureTest(NAME CORE_TEST @@ -212,7 +214,6 @@ if(BUILD_TESTS) test/sparse/connect_components.cu test/sparse/knn.cu test/sparse/knn_graph.cu - test/sparse/linkage.cu OPTIONAL DIST NN ) diff --git a/cpp/test/matrix/columnSort.cu b/cpp/test/matrix/columnSort.cu index 325ed0204b..aba1c4e1f0 100644 --- a/cpp/test/matrix/columnSort.cu +++ b/cpp/test/matrix/columnSort.cu @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -55,12 +56,11 @@ template class ColumnSort : public ::testing::TestWithParam> { protected: ColumnSort() - : keyIn(0, stream), - keySorted(0, stream), - keySortGolden(0, stream), - valueOut(0, stream), - goldenValOut(0, stream), - workspacePtr(0, stream) + : keyIn(0, handle.get_stream()), + keySorted(0, handle.get_stream()), + 
keySortGolden(0, handle.get_stream()), + valueOut(0, handle.get_stream()), + goldenValOut(0, handle.get_stream()) { } @@ -68,13 +68,12 @@ class ColumnSort : public ::testing::TestWithParam> { { params = ::testing::TestWithParam>::GetParam(); int len = params.n_row * params.n_col; - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); - keyIn.resize(len, stream); - valueOut.resize(len, stream); - goldenValOut.resize(len, stream); + keyIn.resize(len, handle.get_stream()); + valueOut.resize(len, handle.get_stream()); + goldenValOut.resize(len, handle.get_stream()); if (params.testKeys) { - keySorted.resize(len, stream); - keySortGolden.resize(len, stream); + keySorted.resize(len, handle.get_stream()); + keySortGolden.resize(len, handle.get_stream()); } std::vector vals(len); @@ -97,45 +96,30 @@ class ColumnSort : public ::testing::TestWithParam> { } } - raft::update_device(keyIn.data(), &vals[0], len, stream); - raft::update_device(goldenValOut.data(), &cValGolden[0], len, stream); - - if (params.testKeys) raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, stream); - - bool needWorkspace = false; - size_t workspaceSize = 0; - // Remove this branch once the implementation of descending sort is fixed. - sort_cols_per_row(keyIn.data(), - valueOut.data(), - params.n_row, - params.n_col, - needWorkspace, - NULL, - workspaceSize, - stream, - keySorted.data()); - if (needWorkspace) { - workspacePtr.resize(workspaceSize, stream); - sort_cols_per_row(keyIn.data(), - valueOut.data(), - params.n_row, - params.n_col, - needWorkspace, - workspacePtr.data(), - workspaceSize, - stream, - keySorted.data()); - } - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - RAFT_CUDA_TRY(cudaStreamDestroy(stream)); + raft::update_device(keyIn.data(), &vals[0], len, handle.get_stream()); + raft::update_device(goldenValOut.data(), &cValGolden[0], len, handle.get_stream()); + + if (params.testKeys) + raft::update_device(keySortGolden.data(), &cKeyGolden[0], len, handle.get_stream()); + + auto key_in_view = raft::make_device_matrix_view( + keyIn.data(), params.n_row, params.n_col); + auto value_out_view = raft::make_device_matrix_view( + valueOut.data(), params.n_row, params.n_col); + auto key_sorted_view = raft::make_device_matrix_view( + keySorted.data(), params.n_row, params.n_col); + + raft::matrix::sort_cols_per_row( + handle, key_in_view, value_out_view, std::make_optional(key_sorted_view)); + + RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } protected: - cudaStream_t stream = 0; columnSort params; rmm::device_uvector keyIn, keySorted, keySortGolden; rmm::device_uvector valueOut, goldenValOut; // valueOut are indexes - rmm::device_uvector workspacePtr; + raft::handle_t handle; }; const std::vector> inputsf1 = {{0.000001f, 503, 2000, false}, diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 61f6bff040..4b3244913b 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -17,6 +17,7 @@ #include "../test_utils.h" #include #include +#include #include #include #include @@ -45,19 +46,6 @@ void naiveGather( naiveGatherImpl(in, D, N, map, map_length, out); } -template -void gatherLaunch(MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, - cudaStream_t stream) -{ - typedef typename std::iterator_traits::value_type MapValueT; - matrix::gather(in, D, N, map, map_length, out, stream); -} - struct GatherInputs { uint32_t nrows; uint32_t ncols; @@ -109,8 +97,18 @@ class GatherTest : public ::testing::TestWithParam { 
naiveGather(h_in.data(), ncols, nrows, h_map.data(), map_length, h_out.data()); raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); - // launch device version of the kernel - gatherLaunch(d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + auto in_view = raft::make_device_matrix_view( + d_in.data(), nrows, ncols); + auto out_view = + raft::make_device_matrix_view(d_out_act.data(), map_length, ncols); + auto map_view = + raft::make_device_vector_view(d_map.data(), map_length); + + raft::matrix::gather(handle, in_view, map_view, out_view); + + // // launch device version of the kernel + // gatherLaunch( + // handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); handle.sync_stream(stream); } diff --git a/cpp/test/matrix/linewise_op.cu b/cpp/test/matrix/linewise_op.cu index 16e2ceb29a..9d3d5af51e 100644 --- a/cpp/test/matrix/linewise_op.cu +++ b/cpp/test/matrix/linewise_op.cu @@ -18,9 +18,10 @@ #include "../test_utils.h" #include #include -#include +#include +#include #include -#include +#include #include #include #include @@ -54,23 +55,39 @@ struct LinewiseTest : public ::testing::TestWithParam + void runLinewiseSum(T* out, const T* in, const I lineLen, const I nLines, const T* vec) { - auto f = [] __device__(T a, T b) -> T { return a + b; }; - matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec); + auto f = [] __device__(T a, T b) -> T { return a + b; }; + constexpr auto rowmajor = std::is_same_v; + + I m = rowmajor ? lineLen : nLines; + I n = rowmajor ? nLines : lineLen; + + auto in_view = raft::make_device_matrix_view(in, m, n); + auto out_view = raft::make_device_matrix_view(out, m, n); + + auto vec_view = raft::make_device_vector_view(vec, m); + matrix::linewise_op(handle, in_view, out_view, raft::is_row_major(in_view), f, vec_view); } - void runLinewiseSum(T* out, - const T* in, - const I lineLen, - const I nLines, - const bool alongLines, - const T* vec1, - const T* vec2) + template + void runLinewiseSum( + T* out, const T* in, const I lineLen, const I nLines, const T* vec1, const T* vec2) { - auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; - matrix::linewiseOp(out, in, lineLen, nLines, alongLines, f, stream, vec1, vec2); + auto f = [] __device__(T a, T b, T c) -> T { return a + b + c; }; + constexpr auto rowmajor = std::is_same_v; + + I m = rowmajor ? lineLen : nLines; + I n = rowmajor ? 
nLines : lineLen; + + auto in_view = raft::make_device_matrix_view(in, m, n); + auto out_view = raft::make_device_matrix_view(out, m, n); + auto vec1_view = raft::make_device_vector_view(vec1, m); + auto vec2_view = raft::make_device_vector_view(vec2, m); + + matrix::linewise_op( + handle, in_view, out_view, raft::is_row_major(in_view), f, vec1_view, vec2_view); } rmm::device_uvector genData(size_t workSizeBytes) @@ -149,7 +166,11 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1); + } else { + runLinewiseSum(out, in, lineLen, nLines, vec1); + } } if (params.checkCorrectness) { linalg::naiveMatVec( @@ -161,7 +182,12 @@ struct LinewiseTest : public ::testing::TestWithParam(out, in, lineLen, nLines, vec1, vec2); + + } else { + runLinewiseSum(out, in, lineLen, nLines, vec1, vec2); + } } if (params.checkCorrectness) { linalg::naiveMatVec( diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index d550852150..ad4a37825c 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -16,7 +16,15 @@ #include "../test_utils.h" #include -#include + +#include +#include +#include +#include +#include +#include +#include + #include #include @@ -43,7 +51,7 @@ template __global__ void nativeSqrtKernel(Type* in, Type* out, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < len) { out[idx] = sqrt(in[idx]); } + if (idx < len) { out[idx] = std::sqrt(in[idx]); } } template @@ -147,16 +155,24 @@ class MathTest : public ::testing::TestWithParam> { uniform(handle, r, in_sign_flip.data(), len, T(-100.0), T(100.0)); naivePower(in_power.data(), out_power_ref.data(), len, stream); - power(in_power.data(), len, stream); + + auto in_power_view = raft::make_device_matrix_view(in_power.data(), len, 1); + power(handle, in_power_view); naiveSqrt(in_sqrt.data(), out_sqrt_ref.data(), len, stream); - seqRoot(in_sqrt.data(), len, stream); - ratio(handle, in_ratio.data(), in_ratio.data(), 4, stream); + auto in_sqrt_view = raft::make_device_matrix_view(in_sqrt.data(), len, 1); + sqrt(handle, in_sqrt_view); + + auto in_ratio_view = raft::make_device_matrix_view(in_ratio.data(), 4, 1); + ratio(handle, in_ratio_view); naiveSignFlip( in_sign_flip.data(), out_sign_flip_ref.data(), params.n_row, params.n_col, stream); - signFlip(in_sign_flip.data(), params.n_row, params.n_col, stream); + + auto in_sign_flip_view = raft::make_device_matrix_view( + in_sign_flip.data(), params.n_row, params.n_col); + sign_flip(handle, in_sign_flip_view); // default threshold is 1e-15 std::vector in_recip_h = {0.1, 0.01, -0.01, 0.1e-16}; @@ -165,18 +181,28 @@ class MathTest : public ::testing::TestWithParam> { update_device(in_recip_ref.data(), in_recip_ref_h.data(), 4, stream); T recip_scalar = T(1.0); + auto in_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + auto out_recip_view = raft::make_device_matrix_view(out_recip.data(), 4, 1); + // this `reciprocal()` has to go first bc next one modifies its input - reciprocal(in_recip.data(), out_recip.data(), recip_scalar, 4, stream); + reciprocal( + handle, in_recip_view, out_recip_view, raft::make_host_scalar_view(&recip_scalar)); - reciprocal(in_recip.data(), recip_scalar, 4, stream, true); + auto inout_recip_view = raft::make_device_matrix_view(in_recip.data(), 4, 1); + + reciprocal(handle, inout_recip_view, raft::make_host_scalar_view(&recip_scalar), true); std::vector in_small_val_zero_h = {0.1, 1e-16, -1e-16, -0.1}; std::vector in_small_val_zero_ref_h = {0.1, 0.0, 0.0, -0.1}; + auto in_smallzero_view = 
raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto inout_smallzero_view = raft::make_device_matrix_view(in_smallzero.data(), 4, 1); + auto out_smallzero_view = raft::make_device_matrix_view(out_smallzero.data(), 4, 1); + update_device(in_smallzero.data(), in_small_val_zero_h.data(), 4, stream); update_device(out_smallzero_ref.data(), in_small_val_zero_ref_h.data(), 4, stream); - setSmallValuesZero(out_smallzero.data(), in_smallzero.data(), 4, stream); - setSmallValuesZero(in_smallzero.data(), 4, stream); + zero_small_values(handle, in_smallzero_view, out_smallzero_view); + zero_small_values(handle, inout_smallzero_view); handle.sync_stream(stream); } diff --git a/cpp/test/matrix/matrix.cu b/cpp/test/matrix/matrix.cu index 6ccd7aa335..78391d5ff2 100644 --- a/cpp/test/matrix/matrix.cu +++ b/cpp/test/matrix/matrix.cu @@ -16,7 +16,9 @@ #include "../test_utils.h" #include -#include +#include + +#include #include #include #include @@ -61,12 +63,19 @@ class MatrixTest : public ::testing::TestWithParam> { int len = params.n_row * params.n_col; uniform(handle, r, in1.data(), len, T(-1.0), T(1.0)); - copy(in1.data(), in2.data(), params.n_row, params.n_col, stream); + auto in1_view = raft::make_device_matrix_view( + in1.data(), params.n_row, params.n_col); + auto in2_view = + raft::make_device_matrix_view(in2.data(), params.n_row, params.n_col); + + copy(handle, in1_view, in2_view); // copy(in1, in1_revr, params.n_row, params.n_col); // colReverse(in1_revr, params.n_row, params.n_col); rmm::device_uvector outTrunc(6, stream); - truncZeroOrigin(in1.data(), params.n_row, outTrunc.data(), 3, 2, stream); + + auto out_trunc_view = raft::make_device_matrix_view(outTrunc.data(), 3, 2); + trunc_zero_origin(handle, in1_view, out_trunc_view); handle.sync_stream(stream); } @@ -128,24 +137,25 @@ class MatrixCopyRowsTest : public ::testing::Test { void testCopyRows() { - copyRows(input.data(), - n_rows, - n_cols, - output.data(), - indices.data(), - n_selected, - handle.get_stream(), - false); + auto input_view = raft::make_device_matrix_view( + input.data(), n_rows, n_cols); + auto output_view = raft::make_device_matrix_view( + output.data(), n_selected, n_cols); + + auto indices_view = + raft::make_device_vector_view(indices.data(), n_selected); + + raft::matrix::copy_rows(handle, input_view, output_view, indices_view); + EXPECT_TRUE(raft::devArrMatchHost( output_exp_colmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); - copyRows(input.data(), - n_rows, - n_cols, - output.data(), - indices.data(), - n_selected, - handle.get_stream(), - true); + + auto input_row_view = raft::make_device_matrix_view( + input.data(), n_rows, n_cols); + auto output_row_view = raft::make_device_matrix_view( + output.data(), n_selected, n_cols); + + raft::matrix::copy_rows(handle, input_row_view, output_row_view, indices_view); EXPECT_TRUE(raft::devArrMatchHost( output_exp_rowmajor, output.data(), n_selected * n_cols, raft::Compare(), stream)); } diff --git a/cpp/test/sparse/linkage.cu b/cpp/test/sparse/linkage.cu index e9df5e3df0..ce5741d06b 100644 --- a/cpp/test/sparse/linkage.cu +++ b/cpp/test/sparse/linkage.cu @@ -24,6 +24,7 @@ #include #endif +#include #include #include @@ -175,23 +176,24 @@ class LinkageTest : public ::testing::TestWithParam> { raft::copy(data.data(), params.data.data(), data.size(), stream); raft::copy(labels_ref.data(), params.expected_labels.data(), params.n_row, stream); - raft::hierarchy::linkage_output out_arrs; - out_arrs.labels = labels.data(); - rmm::device_uvector 
out_children(params.n_row * 2, stream); - out_arrs.children = out_children.data(); - raft::handle_t handle; - raft::hierarchy::single_linkage( + + auto data_view = + raft::make_device_matrix_view(data.data(), params.n_row, params.n_col); + auto dendrogram_view = + raft::make_device_matrix_view(out_children.data(), params.n_row, 2); + auto labels_view = raft::make_device_vector_view(labels.data(), params.n_row); + + raft::cluster::single_linkage( handle, - data.data(), - params.n_row, - params.n_col, + data_view, + dendrogram_view, + labels_view, raft::distance::DistanceType::L2SqrtExpanded, - &out_arrs, - params.c, - params.n_clusters); + params.n_clusters, + std::make_optional(params.c)); handle.sync_stream(stream);
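
The hunks above replace raw-pointer signatures with `mdspan`-based views throughout `raft::matrix` and the ball-cover index. As a minimal sketch of how the new view-taking overloads added in this patch might be called together (this example is not part of the patch; the function name `example`, the include paths, and the matrix shape are illustrative assumptions, and the matrix contents are assumed to be filled elsewhere):

```cpp
#include <raft/core/handle.hpp>
#include <raft/core/device_mdarray.hpp>   // make_device_matrix (path assumed)
#include <raft/matrix/sqrt.cuh>           // weighted_sqrt, added in this patch
#include <raft/matrix/threshold.cuh>      // zero_small_values, added in this patch

void example(const raft::handle_t& handle)
{
  // Owning 100 x 16 column-major device matrix; contents assumed to be
  // produced by earlier computation.
  auto mat = raft::make_device_matrix<float, int, raft::col_major>(handle, 100, 16);

  // Scalar-weighted square root, in place, passing the scalar as a host
  // scalar view and clamping negative inputs to zero.
  float scalar = 0.5f;
  raft::matrix::weighted_sqrt(
    handle, mat.view(), raft::make_host_scalar_view(&scalar), /*set_neg_zero=*/true);

  // Zero out entries below the default 1e-15 threshold, in place.
  raft::matrix::zero_small_values(handle, mat.view());
}
```

The pattern mirrors the test conversions in this section: callers build `device_matrix_view`/`device_vector_view`/`host_scalar_view` wrappers over their data (or use owning `make_device_matrix` containers) and pass them with the `raft::handle_t`, instead of passing raw pointers, extents, and a stream.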