diff --git a/README.md b/README.md index a79679c579..54dd394a69 100755 --- a/README.md +++ b/README.md @@ -33,37 +33,51 @@ The Python API is being improved to wrap the algorithms and primitives from the ## Getting started ### Rapids Memory Manager (RMM) -RAFT relies heavily on [RMM](https://github.com/rapidsai/rmm) which, -like other projects in the RAPIDS ecosystem, eases the burden of configuring different allocation strategies globally -across the libraries that use it. RMM also provides [RAII](https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization)) wrappers around device arrays that handle the allocation and cleanup. + +RAFT relies heavily on RMM which, like other projects in the RAPIDS ecosystem, eases the burden of configuring different allocation strategies globally across the libraries that use it. + +### Multi-dimensional Arrays + +The APIs in RAFT currently accept raw pointers to device memory and we are in the process of simplifying the APIs with the [mdspan](https://arxiv.org/abs/2010.06474) multi-dimensional array view for representing data in higher dimensions similar to the `ndarray` in the Numpy Python library. RAFT also contains the corresponding owning `mdarray` structure, which simplifies the allocation and management of multi-dimensional data in both host and device (GPU) memory. + +The `mdarray` forms a convenience layer over RMM and can be constructed in RAFT using a number of different helper functions: + +```c++ +#include + +int n_rows = 10; +int n_cols = 10; + +auto scalar = raft::make_device_scalar(handle, 1.0); +auto vector = raft::make_device_vector(handle, n_cols); +auto matrix = raft::make_device_matrix(handle, n_rows, n_cols); +``` ### C++ Example Most of the primitives in RAFT accept a `raft::handle_t` object for the management of resources which are expensive to create, such CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`. -The example below demonstrates creating a RAFT handle and using it with RMM's `device_uvector` to allocate memory on device and compute +The example below demonstrates creating a RAFT handle and using it with `device_matrix` and `device_vector` to allocate memory, generating random clusters, and computing pairwise Euclidean distances: ```c++ #include -#include +#include +#include +#include -#include raft::handle_t handle; -int n_samples = ...; -int n_features = ...; +int n_samples = 5000; +int n_features = 50; -rmm::device_uvector input(n_samples * n_features, handle.get_stream()); -rmm::device_uvector output(n_samples * n_samples, handle.get_stream()); +auto input = raft::make_device_matrix(handle, n_samples, n_features); +auto labels = raft::make_device_vector(handle, n_samples); +auto output = raft::make_device_matrix(handle, n_samples, n_samples); -// ... Populate feature matrix ... +raft::random::make_blobs(handle, input, labels); auto metric = raft::distance::DistanceType::L2SqrtExpanded; -rmm::device_uvector workspace(0, handle.get_stream()); -raft::distance::pairwise_distance(handle, input.data(), input.data(), - output.data(), - n_samples, n_samples, n_features, - workspace.data(), metric); +raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); ``` ## Installing @@ -159,3 +173,26 @@ The folder structure mirrors other RAPIDS repos (cuDF, cuML, cuGraph...), with t ## Contributing If you are interested in contributing to the RAFT project, please read our [Contributing guidelines](CONTRIBUTING.md). 
Refer to the [Developer Guide](DEVELOPER_GUIDE.md) for details on the developer guidelines, workflows, and principals. + +## References + +When citing RAFT generally, please consider referencing this Github project. +```bibtex +@misc{rapidsai, + title={Rapidsai/raft: RAFT contains fundamental widely-used algorithms and primitives for data science, Graph and machine learning.}, + url={https://github.com/rapidsai/raft}, + journal={GitHub}, + publisher={Nvidia RAPIDS}, + author={Rapidsai}, + year={2022} +} +``` +If citing the sparse pairwise distances API, please consider using the following bibtex: +```bibtex +@article{nolet2021semiring, + title={Semiring primitives for sparse neighborhood methods on the gpu}, + author={Nolet, Corey J and Gala, Divye and Raff, Edward and Eaton, Joe and Rees, Brad and Zedlewski, John and Oates, Tim}, + journal={arXiv preprint arXiv:2104.06357}, + year={2021} +} +``` \ No newline at end of file diff --git a/cpp/include/raft.hpp b/cpp/include/raft.hpp index 08f836d3a8..f942692aeb 100644 --- a/cpp/include/raft.hpp +++ b/cpp/include/raft.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,6 +14,10 @@ * limitations under the License. */ +#include "raft/handle.hpp" +#include "raft/mdarray.hpp" +#include "raft/span.hpp" + #include namespace raft { diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 71c9e8d32b..e13cfd94f8 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -23,6 +23,8 @@ #include #include +#include + namespace raft { namespace distance { @@ -144,6 +146,35 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In return detail::getWorkspaceSize(x, y, m, n, k); } +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points (size m*k) + * @param y second set of points (size n*k) + * @return number of bytes needed in workspace + * + * @note If the specified distanceType doesn't need the workspace at all, it + * returns 0. 
+ */ +template +size_t getWorkspaceSize(const raft::device_matrix_view x, + const raft::device_matrix_view y) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + + return getWorkspaceSize( + x.data(), y.data(), x.extent(0), y.extent(0), x.extent(1)); +} + /** * @brief Evaluate pairwise distances for the simple use case * @tparam DistanceType which distance to evaluate @@ -160,9 +191,6 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) - * - * @note if workspace is passed as nullptr, this will return in - * worksize, the number of bytes of workspace required */ template +void distance(raft::handle_t const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + InType metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); + + auto is_rowmajor = std::is_same::value; + + distance(x.data(), + y.data(), + dist.data(), + x.extent(0), + y.extent(0), + x.extent(1), + handle.get_stream(), + is_rowmajor, + metric_arg); +} + /** * @defgroup pairwise_distance pairwise distance prims * @{ @@ -319,6 +399,58 @@ void pairwise_distance(const raft::handle_t& handle, handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } +/** + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param x first matrix of points (size mxk) + * @param y second matrix of points (size nxk) + * @param dist output distance matrix (size mxn) + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void pairwise_distance(raft::handle_t const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); + RAFT_EXPECTS(dist.is_contiguous(), "Output must be contiguous."); + + bool rowmajor = x.stride(0) == 0; + + rmm::device_uvector workspace(0, handle.get_stream()); + + pairwise_distance(handle, + x.data(), + y.data(), + dist.data(), + x.extent(0), + y.extent(0), + x.extent(1), + metric, + rowmajor, + metric_arg); +} + }; // namespace distance }; // 
namespace raft diff --git a/cpp/include/raft/distance/distance.hpp b/cpp/include/raft/distance/distance.hpp index f9fbde50e4..66b4efcede 100644 --- a/cpp/include/raft/distance/distance.hpp +++ b/cpp/include/raft/distance/distance.hpp @@ -28,6 +28,8 @@ #include #include +#include + namespace raft { namespace distance { @@ -149,6 +151,34 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In return detail::getWorkspaceSize(x, y, m, n, k); } +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points (size m*k) + * @param y second set of points (size n*k) + * @return number of bytes needed in workspace + * + * @note If the specified distanceType doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const raft::device_matrix_view& x, + const raft::device_matrix_view& y) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + + return getWorkspaceSize( + x.data(), y.data(), x.extent(0), y.extent(0), x.extent(1)); +} + /** * @brief Evaluate pairwise distances for the simple use case * @tparam DistanceType which distance to evaluate @@ -165,9 +195,6 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) - * - * @note if workspace is passed as nullptr, this will return in - * worksize, the number of bytes of workspace required */ template +void distance(raft::handle_t const handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + InType metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); + + if (x.stride(0) == 0 && y.stride(0) == 0) { + distance(x.data(), + y.data(), + dist.data(), + x.extent(0), + y.extent(0), + x.extent(1), + handle.get_stream(), + true, + metric_arg); + } else if (x.stride(0) > 0 && y.stride(0) > 0) { + distance(x.data(), + y.data(), + dist.data(), + x.extent(0), + y.extent(0), + x.extent(1), + handle.get_stream(), + false, + metric_arg); + } else { + RAFT_FAIL("x and y must both have the same layout: row-major or column-major."); + } +} + +/** * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type * @tparam Index_ indexing type + * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points * @param dist output distance matrix @@ -207,8 +298,8 @@ void distance(const InType* x, * @param workspace temporary workspace buffer which can get resized as per the * needed workspace size * @param metric distance metric - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major + * 
@param metric_arg metric argument */ template void pairwise_distance(const raft::handle_t& handle, @@ -288,15 +379,13 @@ void pairwise_distance(const raft::handle_t& handle, default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; } -/** @} */ /** - * @defgroup pairwise_distance pairwise distance prims - * @{ * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type * @tparam Index_ indexing type + * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points * @param dist output distance matrix @@ -304,8 +393,8 @@ void pairwise_distance(const raft::handle_t& handle, * @param n number of points in y * @param k dimensionality * @param metric distance metric - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument */ template void pairwise_distance(const raft::handle_t& handle, @@ -324,6 +413,54 @@ void pairwise_distance(const raft::handle_t& handle, handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param handle raft handle for managing expensive resources + * @param x first matrix of points (size mxk) + * @param y second matrix of points (size nxk) + * @param dist output distance matrix (size mxn) + * @param metric distance metric + * @param metric_arg metric argument + */ +template +void pairwise_distance(raft::handle_t const& handle, + device_matrix_view const& x, + device_matrix_view const& y, + device_matrix_view& dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_contiguous(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_contiguous(), "Input y must be contiguous."); + RAFT_EXPECTS(dist.is_contiguous(), "Output must be contiguous."); + + bool rowmajor = x.stride(0) == 0; + + rmm::device_uvector workspace(0, handle.get_stream()); + + pairwise_distance(handle, + x.data(), + y.data(), + dist.data(), + x.extent(0), + y.extent(0), + x.extent(1), + metric, + rowmajor, + metric_arg); +} + }; // namespace distance }; // namespace raft diff --git a/cpp/include/raft/mdarray.hpp b/cpp/include/raft/mdarray.hpp index 44ca526c16..f92a0e5e59 100644 --- a/cpp/include/raft/mdarray.hpp +++ b/cpp/include/raft/mdarray.hpp @@ -23,6 +23,7 @@ #pragma once #include #include +#include #include namespace raft { @@ -295,6 +296,10 @@ class mdarray { /** * @brief mdarray with host container policy + * @tparam ElementType the data type of the elements + * @tparam Extents defines the shape + * @tparam LayoutPolicy policy for indexing strides and layout ordering + * @tparam ContainerPolicy storage and accessor policy */ template using host_scalar = host_mdarray; /** * @brief Shorthand for 0-dim host mdarray (scalar). - * - * Similar to rmm::device_scalar, underying storage is rmm::device_uvector. 
+ * @tparam ElementType the data type of the scalar element */ template using device_scalar = device_mdarray; /** * @brief Shorthand for 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements */ template using host_vector = host_mdarray; /** * @brief Shorthand for 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements */ template using device_vector = device_mdarray; /** * @brief Shorthand for c-contiguous host matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering */ template using host_matrix = host_mdarray; /** * @brief Shorthand for c-contiguous device matrix. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering */ template using device_matrix = device_mdarray; /** * @brief Shorthand for 0-dim host mdspan (scalar). + * @tparam ElementType the data type of the scalar element */ template using host_scalar_view = host_mdspan; /** * @brief Shorthand for 0-dim host mdspan (scalar). + * @tparam ElementType the data type of the scalar element */ template using device_scalar_view = device_mdspan; /** * @brief Shorthand for 1-dim host mdspan. + * @tparam ElementType the data type of the vector elements */ template using host_vector_view = host_mdspan; /** * @brief Shorthand for 1-dim device mdspan. + * @tparam ElementType the data type of the vector elements */ template using device_vector_view = device_mdspan; /** * @brief Shorthand for c-contiguous host matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * */ template using host_matrix_view = host_mdspan; + /** * @brief Shorthand for c-contiguous device matrix view. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * */ template using device_matrix_view = device_mdspan; /** * @brief Create a 0-dim (scalar) mdspan instance for host value. + * + * @tparam ElementType the data type of the matrix elements + * @param[in] ptr on device to wrap */ template auto make_host_scalar_view(ElementType* ptr) @@ -400,6 +427,9 @@ auto make_host_scalar_view(ElementType* ptr) /** * @brief Create a 0-dim (scalar) mdspan instance for device value. + * + * @tparam ElementType the data type of the matrix elements + * @param[in] ptr on device to wrap */ template auto make_device_scalar_view(ElementType* ptr) @@ -409,7 +439,14 @@ auto make_device_scalar_view(ElementType* ptr) } /** - * @brief Create a 2-dim c-contiguous mdspan instance for host pointer. + * @brief Create a 2-dim c-contiguous mdspan instance for host pointer. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on host to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer */ template auto make_host_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) @@ -418,7 +455,14 @@ auto make_host_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) return host_matrix_view{ptr, extents}; } /** - * @brief Create a 2-dim c-contiguous mdspan instance for device pointer. + * @brief Create a 2-dim c-contiguous mdspan instance for device pointer. 
It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer */ template auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) @@ -429,6 +473,10 @@ auto make_device_matrix_view(ElementType* ptr, size_t n_rows, size_t n_cols) /** * @brief Create a 1-dim mdspan instance for host pointer. + * @tparam ElementType the data type of the vector elements + * @param[in] ptr on host to wrap + * @param[in] n number of elements in pointer + * @return raft::host_vector_view */ template auto make_host_vector_view(ElementType* ptr, size_t n) @@ -439,6 +487,10 @@ auto make_host_vector_view(ElementType* ptr, size_t n) /** * @brief Create a 1-dim mdspan instance for device pointer. + * @tparam ElementType the data type of the vector elements + * @param[in] ptr on device to wrap + * @param[in] n number of elements in pointer + * @return raft::device_vector_view */ template auto make_device_vector_view(ElementType* ptr, size_t n) @@ -449,6 +501,11 @@ auto make_device_vector_view(ElementType* ptr, size_t n) /** * @brief Create a 2-dim c-contiguous host mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::host_matrix */ template auto make_host_matrix(size_t n_rows, size_t n_cols) @@ -461,6 +518,12 @@ auto make_host_matrix(size_t n_rows, size_t n_cols) /** * @brief Create a 2-dim c-contiguous device mdarray. + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @param[in] stream cuda stream for ordering events + * @return raft::device_matrix */ template auto make_device_matrix(size_t n_rows, size_t n_cols, rmm::cuda_stream_view stream) @@ -471,10 +534,28 @@ auto make_device_matrix(size_t n_rows, size_t n_cols, rmm::cuda_stream_view stre return device_matrix{extents, policy}; } +/** + * @brief Create a 2-dim c-contiguous device mdarray. + * + * @tparam ElementType the data type of the matrix elements + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] handle raft handle for managing expensive resources + * @param[in] n_rows number or rows in matrix + * @param[in] n_cols number of columns in matrix + * @return raft::device_matrix + */ +template +auto make_device_matrix(raft::handle_t const& handle, size_t n_rows, size_t n_cols) +{ + return make_device_matrix(n_rows, n_cols, handle.get_stream()); +} + /** * @brief Create a host scalar from v. * - * Underlying storage is std::vector. + * @tparam ElementType the data type of the scalar element + * @param[in] v scalar type to wrap + * @return raft::host_scalar */ template auto make_host_scalar(ElementType const& v) @@ -493,7 +574,10 @@ auto make_host_scalar(ElementType const& v) /** * @brief Create a device scalar from v. * - * Similar to rmm::device_scalar, underying storage is rmm::device_uvector. 
+ * @tparam ElementType the data type of the scalar element + * @param[in] v scalar type to wrap on device + * @param[in] stream the cuda stream for ordering events + * @return raft::device_scalar */ template auto make_device_scalar(ElementType const& v, rmm::cuda_stream_view stream) @@ -506,8 +590,25 @@ auto make_device_scalar(ElementType const& v, rmm::cuda_stream_view stream) return scalar; } +/** + * @brief Create a device scalar from v. + * + * @tparam ElementType the data type of the scalar element + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] v scalar to wrap on device + * @return raft::device_scalar + */ +template +auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) +{ + return make_device_scalar(v, handle.get_stream()); +} + /** * @brief Create a 1-dim host mdarray. + * @tparam ElementType the data type of the vector elements + * @param[in] n number of elements in vector + * @return raft::host_vector */ template auto make_host_vector(size_t n) @@ -520,6 +621,10 @@ auto make_host_vector(size_t n) /** * @brief Create a 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements + * @param[in] n number of elements in vector + * @param[in] stream the cuda stream for ordering events + * @return raft::device_vector */ template auto make_device_vector(size_t n, rmm::cuda_stream_view stream) @@ -529,4 +634,17 @@ auto make_device_vector(size_t n, rmm::cuda_stream_view stream) policy_t policy{stream}; return device_vector{extents, policy}; } + +/** + * @brief Create a 1-dim device mdarray. + * @tparam ElementType the data type of the vector elements + * @param[in] handle raft handle for managing expensive cuda resources + * @param[in] n number of elements in vector + * @return raft::device_vector + */ +template +auto make_device_vector(raft::handle_t const& handle, size_t n) +{ + return make_device_vector(n, handle.get_stream()); +} } // namespace raft diff --git a/cpp/include/raft/random/make_blobs.cuh b/cpp/include/raft/random/make_blobs.cuh index 2ad3a7960d..088690529a 100644 --- a/cpp/include/raft/random/make_blobs.cuh +++ b/cpp/include/raft/random/make_blobs.cuh @@ -20,6 +20,8 @@ #pragma once #include "detail/make_blobs.cuh" +#include +#include namespace raft::random { @@ -91,6 +93,90 @@ void make_blobs(DataT* out, type); } +/** + * @brief GPU-equivalent of sklearn.datasets.make_blobs + * + * @tparam DataT output data type + * @tparam IdxT indexing arithmetic type + * + * @param[out] out generated data [on device] + * [dim = n_rows x n_cols] + * @param[out] labels labels for the generated data [on device] + * [len = n_rows] + * @param[in] n_rows number of rows in the generated data + * @param[in] n_cols number of columns in the generated data + * @param[in] n_clusters number of clusters (or classes) to generate + * @param[in] stream cuda stream to schedule the work on + * @param[in] row_major whether input `centers` and output `out` + * buffers are to be stored in row or column + * major layout + * @param[in] centers centers of each of the cluster, pass a nullptr + * if you need this also to be generated randomly + * [on device] [dim = n_clusters x n_cols] + * @param[in] cluster_std standard deviation of each cluster center, + * pass a nullptr if this is to be read from the + * `cluster_std_scalar`. [on device] + * [len = n_clusters] + * @param[in] cluster_std_scalar if 'cluster_std' is nullptr, then use this as + * the std-dev across all dimensions. 
+ * @param[in] shuffle shuffle the generated dataset and labels + * @param[in] center_box_min min value of box from which to pick cluster + * centers. Useful only if 'centers' is nullptr + * @param[in] center_box_max max value of box from which to pick cluster + * centers. Useful only if 'centers' is nullptr + * @param[in] seed seed for the RNG + * @param[in] type RNG type + */ +template +void make_blobs(raft::handle_t const& handle, + raft::device_matrix_view out, + raft::device_vector_view labels, + IdxT n_clusters = 5, + std::optional> centers = std::nullopt, + std::optional> const cluster_std = std::nullopt, + const DataT cluster_std_scalar = (DataT)1.0, + bool shuffle = true, + DataT center_box_min = (DataT)-10.0, + DataT center_box_max = (DataT)10.0, + uint64_t seed = 0ULL, + GeneratorType type = GenPhilox) +{ + if (centers.has_value()) { + RAFT_EXPECTS(centers.value().extent(0) == (std::size_t)n_clusters, + "n_centers must equal size of centers"); + } + + if (cluster_std.has_value()) { + RAFT_EXPECTS(cluster_std.value().extent(0) == (std::size_t)n_clusters, + "n_centers must equal size of cluster_std"); + } + + RAFT_EXPECTS(out.extent(0) == labels.extent(0), + "Number of labels must equal the number of row in output matrix"); + + RAFT_EXPECTS(out.is_contiguous(), "Output must be contiguous."); + + bool row_major = std::is_same::value; + + auto prm_centers = centers.has_value() ? centers.value().data() : nullptr; + auto prm_cluster_std = cluster_std.has_value() ? cluster_std.value().data() : nullptr; + + detail::make_blobs_caller(out.data(), + labels.data(), + (IdxT)out.extent(0), + (IdxT)out.extent(1), + n_clusters, + handle.get_stream(), + row_major, + prm_centers, + prm_cluster_std, + cluster_std_scalar, + shuffle, + center_box_min, + center_box_max, + seed, + type); +} } // end namespace raft::random #endif \ No newline at end of file diff --git a/cpp/include/raft/random/make_blobs.hpp b/cpp/include/raft/random/make_blobs.hpp index 19d4b8499b..02aef809e7 100644 --- a/cpp/include/raft/random/make_blobs.hpp +++ b/cpp/include/raft/random/make_blobs.hpp @@ -25,6 +25,8 @@ #pragma once #include "detail/make_blobs.cuh" +#include +#include namespace raft::random { @@ -96,6 +98,85 @@ void make_blobs(DataT* out, type); } +/** + * @brief GPU-equivalent of sklearn.datasets.make_blobs + * + * @tparam DataT output data type + * @tparam IdxT indexing arithmetic type + * + * @param[in] handle raft handle for managing expensive resources + * @param[out] out generated data [on device] + * [dim = n_rows x n_cols] + * @param[out] labels labels for the generated data [on device] + * [len = n_rows] + * @param[in] n_clusters number of clusters (or classes) to generate + * @param[in] centers centers of each of the cluster, pass a nullptr + * if you need this also to be generated randomly + * [on device] [dim = n_clusters x n_cols] + * @param[in] cluster_std standard deviation of each cluster center, + * pass a nullptr if this is to be read from the + * `cluster_std_scalar`. [on device] + * [len = n_clusters] + * @param[in] cluster_std_scalar if 'cluster_std' is nullptr, then use this as + * the std-dev across all dimensions. + * @param[in] shuffle shuffle the generated dataset and labels + * @param[in] center_box_min min value of box from which to pick cluster + * centers. Useful only if 'centers' is nullptr + * @param[in] center_box_max max value of box from which to pick cluster + * centers. 
Useful only if 'centers' is nullptr + * @param[in] seed seed for the RNG + * @param[in] type RNG type + */ +template +void make_blobs(raft::handle_t const& handle, + raft::device_matrix_view out, + raft::device_vector_view labels, + IdxT n_clusters = 5, + std::optional> centers = std::nullopt, + std::optional> const cluster_std = std::nullopt, + const DataT cluster_std_scalar = (DataT)1.0, + bool shuffle = true, + DataT center_box_min = (DataT)-10.0, + DataT center_box_max = (DataT)10.0, + uint64_t seed = 0ULL, + GeneratorType type = GenPhilox) +{ + if (centers.has_value()) { + RAFT_EXPECTS(centers.value().extent(0) == (std::size_t)n_clusters, + "n_centers must equal size of centers"); + } + + if (cluster_std.has_value()) { + RAFT_EXPECTS(cluster_std.value().extent(0) == (std::size_t)n_clusters, + "n_centers must equal size of cluster_std"); + } + + RAFT_EXPECTS(out.extent(0) == labels.extent(0), + "Number of labels must equal the number of row in output matrix"); + + RAFT_EXPECTS(out.is_contiguous(), "Output must be contiguous."); + + bool row_major = std::is_same::value; + + auto prm_centers = centers.has_value() ? centers.value().data() : nullptr; + auto prm_cluster_std = cluster_std.has_value() ? cluster_std.value().data() : nullptr; + + detail::make_blobs_caller(out.data(), + labels.data(), + (IdxT)out.extent(0), + (IdxT)out.extent(1), + n_clusters, + handle.get_stream(), + row_major, + prm_centers, + prm_cluster_std, + cluster_std_scalar, + shuffle, + center_box_min, + center_box_max, + seed, + type); +} } // end namespace raft::random -#endif +#endif \ No newline at end of file diff --git a/cpp/include/raft/spatial/knn/ball_cover.cuh b/cpp/include/raft/spatial/knn/ball_cover.cuh index df797ecca2..62cd5aa45c 100644 --- a/cpp/include/raft/spatial/knn/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/ball_cover.cuh @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - #ifndef __BALL_COVER_H #define __BALL_COVER_H @@ -35,7 +34,7 @@ template & index) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_build_index(handle, index, detail::HaversineFunc()); } else if (index.metric == raft::distance::DistanceType::L2SqrtExpanded || @@ -85,7 +84,7 @@ void rbc_all_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_all_knn_query(handle, index, @@ -152,7 +151,7 @@ void rbc_knn_query(const raft::handle_t& handle, bool perform_post_filtering = true, float weight = 1.0) { - ASSERT(index.n == 2, "Random ball cover currently only works in 2-dimensions"); + ASSERT(index.n <= 3, "only 2d and 3d vectors are supported in current implementation"); if (index.metric == raft::distance::DistanceType::Haversine) { detail::rbc_knn_query(handle, index, diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 1070f18b96..3c4f3d7323 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #if defined RAFT_DISTANCE_COMPILED #include #endif @@ -383,7 +384,7 @@ template return os; } -template +template void distanceLauncher(DataType* x, DataType* y, DataType* dist, @@ -393,14 +394,17 @@ void distanceLauncher(DataType* x, int k, DistanceInputs& params, DataType threshold, - char* workspace, - size_t worksize, cudaStream_t stream, - bool isRowMajor, DataType metric_arg = 2.0f) { - raft::distance::distance( - x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); + raft::handle_t handle(stream); + + auto x_v = make_device_matrix_view(x, m, k); + auto y_v = make_device_matrix_view(y, n, k); + auto dist_v = make_device_matrix_view(dist, m, n); + + raft::distance::distance( + handle, x_v, y_v, dist_v, metric_arg); } template @@ -446,25 +450,39 @@ class DistanceTest : public ::testing::TestWithParam> { } naiveDistance( dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); - size_t worksize = raft::distance::getWorkspaceSize( - x.data(), y.data(), m, n, k); - rmm::device_uvector workspace(worksize, stream); + // size_t worksize = raft::distance::getWorkspaceSize( + // x.data(), y.data(), m, n, k); + // rmm::device_uvector workspace(worksize, stream); DataType threshold = -10000.f; - distanceLauncher(x.data(), - y.data(), - dist.data(), - dist2.data(), - m, - n, - k, - params, - threshold, - workspace.data(), - workspace.size(), - stream, - isRowMajor, - metric_arg); + + if (isRowMajor) { + distanceLauncher(x.data(), + y.data(), + dist.data(), + dist2.data(), + m, + n, + k, + params, + threshold, + stream, + metric_arg); + + } else { + distanceLauncher(x.data(), + y.data(), + dist.data(), + dist2.data(), + m, + n, + k, + params, + threshold, + stream, + metric_arg); + } handle.sync_stream(stream); } diff --git a/cpp/test/random/make_blobs.cu b/cpp/test/random/make_blobs.cu index b2b4ba9e66..36e6e3f838 100644 --- a/cpp/test/random/make_blobs.cu +++ b/cpp/test/random/make_blobs.cu @@ -19,6 +19,7 @@ #include #include #include 
+#include #include namespace raft { @@ -68,19 +69,19 @@ struct MakeBlobsInputs { T tolerance; int rows, cols, n_clusters; T std; - bool row_major, shuffle; + bool shuffle; raft::random::GeneratorType gtype; uint64_t seed; }; -template +template class MakeBlobsTest : public ::testing::TestWithParam> { public: MakeBlobsTest() : params(::testing::TestWithParam>::GetParam()), stream(handle.get_stream()), - mu_vec(params.cols * params.n_clusters, stream), - mean_var(2 * params.n_clusters * params.cols, stream) + mu_vec(make_device_matrix(handle, params.n_clusters, params.cols)), + mean_var(make_device_vector(handle, 2 * params.n_clusters * params.cols)) { } @@ -93,32 +94,31 @@ class MakeBlobsTest : public ::testing::TestWithParam> { auto len = params.rows * params.cols; raft::random::Rng r(params.seed, params.gtype); - rmm::device_uvector data(len, stream); - rmm::device_uvector labels(params.rows, stream); - rmm::device_uvector stats(2 * params.n_clusters * params.cols, stream); - rmm::device_uvector lens(params.n_clusters, stream); + auto data = make_device_matrix(handle, params.rows, params.cols); + auto labels = make_device_vector(handle, params.rows); + auto stats = make_device_vector(handle, 2 * params.n_clusters * params.cols); + auto lens = make_device_vector(handle, params.n_clusters); - RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, stats.size() * sizeof(T), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(stats.data(), 0, stats.extent(0) * sizeof(T), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(lens.data(), 0, lens.extent(0) * sizeof(int), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(mean_var.data(), 0, mean_var.size() * sizeof(T), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(lens.data(), 0, lens.size() * sizeof(int), stream)); r.uniform(mu_vec.data(), params.cols * params.n_clusters, T(-10.0), T(10.0), stream); - T* sigma_vec = nullptr; - make_blobs(data.data(), - labels.data(), - params.rows, - params.cols, - params.n_clusters, - stream, - params.row_major, - mu_vec.data(), - sigma_vec, - params.std, - params.shuffle, - T(-10.0), - T(10.0), - params.seed, - params.gtype); + + make_blobs(handle, + data.view(), + labels.view(), + params.n_clusters, + std::make_optional(mu_vec.view()), + std::nullopt, + params.std, + params.shuffle, + T(-10.0), + T(10.0), + params.seed, + params.gtype); + + bool row_major = std::is_same::value; static const int threads = 128; meanKernel<<>>(stats.data(), lens.data(), @@ -127,10 +127,10 @@ class MakeBlobsTest : public ::testing::TestWithParam> { params.rows, params.cols, params.n_clusters, - params.row_major); + row_major); int len1 = params.n_clusters * params.cols; compute_mean_var<<>>( - mean_var.data(), stats.data(), lens.data(), params.n_clusters, params.cols, params.row_major); + mean_var.data(), stats.data(), lens.data(), params.n_clusters, params.cols, row_major); } void check() @@ -146,87 +146,66 @@ class MakeBlobsTest : public ::testing::TestWithParam> { raft::handle_t handle; cudaStream_t stream = 0; - rmm::device_uvector mu_vec, mean_var; + device_vector mean_var; + device_matrix mu_vec; int num_sigma; }; -typedef MakeBlobsTest MakeBlobsTestF; +typedef MakeBlobsTest MakeBlobsTestF_RowMajor; +typedef MakeBlobsTest MakeBlobsTestF_ColMajor; + const std::vector> inputsf_t = { - {0.0055, 1024, 32, 3, 1.f, true, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, true, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.f, true, false, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, true, false, raft::random::GenPC, 
1234ULL}, - {0.0055, 1024, 32, 3, 1.f, false, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, false, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.f, false, false, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, false, false, raft::random::GenPC, 1234ULL}, - {0.0055, 1024, 32, 3, 1.f, true, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, true, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.f, true, true, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, true, true, raft::random::GenPC, 1234ULL}, - {0.0055, 1024, 32, 3, 1.f, false, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, false, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.f, false, true, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.f, false, true, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, true, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, true, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, true, false, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, true, false, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, false, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, false, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, false, false, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, false, false, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, true, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, true, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, true, true, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, true, true, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, false, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, false, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.f, false, true, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.f, false, true, raft::random::GenPC, 1234ULL}, + {0.0055, 1024, 32, 3, 1.f, false, raft::random::GenPhilox, 1234ULL}, + {0.011, 1024, 8, 3, 1.f, false, raft::random::GenPhilox, 1234ULL}, + {0.0055, 1024, 32, 3, 1.f, false, raft::random::GenPC, 1234ULL}, + {0.011, 1024, 8, 3, 1.f, false, raft::random::GenPC, 1234ULL}, + {0.0055, 1024, 32, 3, 1.f, true, raft::random::GenPhilox, 1234ULL}, + {0.011, 1024, 8, 3, 1.f, true, raft::random::GenPhilox, 1234ULL}, + {0.0055, 1024, 32, 3, 1.f, true, raft::random::GenPC, 1234ULL}, + {0.011, 1024, 8, 3, 1.f, true, raft::random::GenPC, 1234ULL}, + {0.0055, 5003, 32, 5, 1.f, false, raft::random::GenPhilox, 1234ULL}, + {0.011, 5003, 8, 5, 1.f, false, raft::random::GenPhilox, 1234ULL}, + {0.0055, 5003, 32, 5, 1.f, false, raft::random::GenPC, 1234ULL}, + {0.011, 5003, 8, 5, 1.f, false, raft::random::GenPC, 1234ULL}, + {0.0055, 5003, 32, 5, 1.f, true, raft::random::GenPhilox, 1234ULL}, + {0.011, 5003, 8, 5, 1.f, true, raft::random::GenPhilox, 1234ULL}, + {0.0055, 5003, 32, 5, 1.f, true, raft::random::GenPC, 1234ULL}, + {0.011, 5003, 8, 5, 1.f, true, raft::random::GenPC, 1234ULL}, }; -TEST_P(MakeBlobsTestF, Result) { check(); } -INSTANTIATE_TEST_CASE_P(MakeBlobsTests, MakeBlobsTestF, ::testing::ValuesIn(inputsf_t)); +TEST_P(MakeBlobsTestF_RowMajor, Result) { check(); } +INSTANTIATE_TEST_CASE_P(MakeBlobsTests, MakeBlobsTestF_RowMajor, ::testing::ValuesIn(inputsf_t)); + +TEST_P(MakeBlobsTestF_ColMajor, Result) { check(); } +INSTANTIATE_TEST_CASE_P(MakeBlobsTests, 
MakeBlobsTestF_ColMajor, ::testing::ValuesIn(inputsf_t)); + +typedef MakeBlobsTest MakeBlobsTestD_RowMajor; +typedef MakeBlobsTest MakeBlobsTestD_ColMajor; -typedef MakeBlobsTest MakeBlobsTestD; const std::vector> inputsd_t = { - {0.0055, 1024, 32, 3, 1.0, true, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, true, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, true, false, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, true, false, raft::random::GenPC, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, false, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, false, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, false, false, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, false, false, raft::random::GenPC, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, true, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, true, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, true, true, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, true, true, raft::random::GenPC, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, false, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, false, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 1024, 32, 3, 1.0, false, true, raft::random::GenPC, 1234ULL}, - {0.011, 1024, 8, 3, 1.0, false, true, raft::random::GenPC, 1234ULL}, - - {0.0055, 5003, 32, 5, 1.0, true, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, true, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, true, false, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, true, false, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, false, false, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, false, false, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, false, false, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, false, false, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, true, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, true, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, true, true, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, true, true, raft::random::GenPC, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, false, true, raft::random::GenPhilox, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, false, true, raft::random::GenPhilox, 1234ULL}, - {0.0055, 5003, 32, 5, 1.0, false, true, raft::random::GenPC, 1234ULL}, - {0.011, 5003, 8, 5, 1.0, false, true, raft::random::GenPC, 1234ULL}, + {0.0055, 1024, 32, 3, 1.0, false, raft::random::GenPhilox, 1234ULL}, + {0.011, 1024, 8, 3, 1.0, false, raft::random::GenPhilox, 1234ULL}, + {0.0055, 1024, 32, 3, 1.0, false, raft::random::GenPC, 1234ULL}, + {0.011, 1024, 8, 3, 1.0, false, raft::random::GenPC, 1234ULL}, + {0.0055, 1024, 32, 3, 1.0, true, raft::random::GenPhilox, 1234ULL}, + {0.011, 1024, 8, 3, 1.0, true, raft::random::GenPhilox, 1234ULL}, + {0.0055, 1024, 32, 3, 1.0, true, raft::random::GenPC, 1234ULL}, + {0.011, 1024, 8, 3, 1.0, true, raft::random::GenPC, 1234ULL}, + + {0.0055, 5003, 32, 5, 1.0, false, raft::random::GenPhilox, 1234ULL}, + {0.011, 5003, 8, 5, 1.0, false, raft::random::GenPhilox, 1234ULL}, + {0.0055, 5003, 32, 5, 1.0, false, raft::random::GenPC, 1234ULL}, + {0.011, 5003, 8, 5, 1.0, false, raft::random::GenPC, 1234ULL}, + {0.0055, 5003, 32, 5, 1.0, true, raft::random::GenPhilox, 1234ULL}, + {0.011, 5003, 8, 5, 1.0, true, raft::random::GenPhilox, 1234ULL}, + 
{0.0055, 5003, 32, 5, 1.0, true, raft::random::GenPC, 1234ULL}, + {0.011, 5003, 8, 5, 1.0, true, raft::random::GenPC, 1234ULL}, }; -TEST_P(MakeBlobsTestD, Result) { check(); } -INSTANTIATE_TEST_CASE_P(MakeBlobsTests, MakeBlobsTestD, ::testing::ValuesIn(inputsd_t)); +TEST_P(MakeBlobsTestD_RowMajor, Result) { check(); } +INSTANTIATE_TEST_CASE_P(MakeBlobsTests, MakeBlobsTestD_RowMajor, ::testing::ValuesIn(inputsd_t)); + +TEST_P(MakeBlobsTestD_ColMajor, Result) { check(); } +INSTANTIATE_TEST_CASE_P(MakeBlobsTests, MakeBlobsTestD_ColMajor, ::testing::ValuesIn(inputsd_t)); } // end namespace random } // end namespace raft
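
For reference, the mdspan-based overloads added in this patch compose end-to-end roughly as follows. This is a usage sketch, not part of the diff: the explicit `float`/`int` template arguments, the `.view()` calls, and the `main()` wrapper are assumptions inferred from the signatures above, and the translation unit would need to be built with nvcc since `distance.cuh` and `make_blobs.cuh` are CUDA headers.

```c++
#include <raft/handle.hpp>
#include <raft/mdarray.hpp>
#include <raft/random/make_blobs.cuh>
#include <raft/distance/distance.cuh>

int main()
{
  raft::handle_t handle;

  int n_samples  = 1000;
  int n_features = 16;

  // Owning device mdarrays allocated through the new handle-based helpers,
  // which pull the CUDA stream from the handle instead of taking it explicitly.
  auto input  = raft::make_device_matrix<float>(handle, n_samples, n_features);
  auto labels = raft::make_device_vector<int>(handle, n_samples);
  auto output = raft::make_device_matrix<float>(handle, n_samples, n_samples);

  // Generate random clusters directly into the owning arrays via their views.
  raft::random::make_blobs(handle, input.view(), labels.view());

  // Runtime-dispatched pairwise distances over the same views; the zero-size
  // workspace is created and resized internally, so callers no longer manage
  // an rmm::device_uvector themselves.
  auto metric = raft::distance::DistanceType::L2SqrtExpanded;
  raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);

  handle.sync_stream(handle.get_stream());
  return 0;
}
```

This is the same pattern the updated README example above exercises; the existing raw-pointer entry points remain in place for callers that have not yet adopted `mdspan` views.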