From 01dd067d3614745e5928ac09c807d9b1c35e9f9b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 31 Oct 2022 16:26:07 -0700 Subject: [PATCH 01/12] Expose `linalg::dot` in public API Closes https://github.com/rapidsai/raft/issues/805 --- cpp/include/raft/core/device_mdspan.hpp | 22 +++- cpp/include/raft/linalg/dot.cuh | 70 +++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/linalg/axpy.cu | 11 -- cpp/test/linalg/dot.cu | 152 ++++++++++++++++++++++++ 5 files changed, 244 insertions(+), 12 deletions(-) create mode 100644 cpp/include/raft/linalg/dot.cuh create mode 100644 cpp/test/linalg/dot.cu diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 394ea228b4..dc2acb13df 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -276,4 +276,24 @@ auto make_device_vector_view(ElementType* ptr, IndexType n) return device_vector_view{ptr, n}; } -} // end namespace raft \ No newline at end of file +/** + * @brief Create a 1-dim mdspan instance for device pointer, using a strided layout + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] n number of elements in pointer + * @param[in] stride stride between elements + * @return raft::device_vector_view + */ +template +auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride) +{ + vector_extent exts{n}; + std::array strides{stride}; + auto layout = typename LayoutPolicy::template mapping>{exts, strides}; + return device_vector_view{ptr, layout}; +} +} // end namespace raft diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh new file mode 100644 index 0000000000..ec3631a042 --- /dev/null +++ b/cpp/include/raft/linalg/dot.cuh @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __DOT_H +#define __DOT_H + +#pragma once + +#include + +#include +#include +#include + +namespace raft::linalg { + +/** + * @brief Computes the dot product of two vectors. + * @tparam InputType1 raft::device_mdspan for the first input vector + * @tparam InputType2 raft::device_mdspan for the second input vector + * @tparam OutputType Either a host_scalar_view or device_scalar_view for the output + * @param[in] handle raft::handle_t + * @param[in] x First input vector + * @param[in] y Second input vector + * @param[out] out The output dot product between the x and y vectors + */ +template , + typename = raft::enable_if_input_device_mdspan, + typename = raft::enable_if_output_mdspan> +void dot(const raft::handle_t& handle, InputType1 x, InputType2 y, OutputType out) +{ + RAFT_EXPECTS(x.size() == y.size(), + "Size mismatch between x and y input vectors in raft::linalg::dot"); + + // Right now the inputs and outputs need to all have the same value_type (float/double etc). + // Try to output a meaningful compiler error if mismatched types are passed here. + // Note: In the future we could remove this restriction using the cublasDotEx function + // in the cublas wrapper call, instead of the cublassdot and cublasddot functions. + static_assert(std::is_same_v, + "Both input vectors need to have the same value_type in raft::linalg::dot call"); + static_assert( + std::is_same_v, + "Input vectors and output scalar need to have the same value_type in raft::linalg::dot call"); + + RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + x.size(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + out.data_handle(), + handle.get_stream())); +} +} // namespace raft::linalg +#endif diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 4280be91ff..87b587d032 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -134,6 +134,7 @@ if(BUILD_TESTS) test/linalg/cholesky_r1.cu test/linalg/coalesced_reduction.cu test/linalg/divide.cu + test/linalg/dot.cu test/linalg/eig.cu test/linalg/eig_sel.cu test/linalg/gemm_layout.cu diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index fdab845914..8ed0302b86 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -31,17 +31,6 @@ __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx if (idx < n) { y[idx * incy] += alpha * x[idx * incx]; } } -template -auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride) -{ - vector_extent exts{n}; - std::array strides{stride}; - auto layout = typename LayoutPolicy::mapping>{exts, strides}; - return device_vector_view{ptr, layout}; -} - template struct AxpyInputs { OutType tolerance; diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu new file mode 100644 index 0000000000..24c5aa4f7b --- /dev/null +++ b/cpp/test/linalg/dot.cu @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "../test_utils.h" +#include +#include +#include +#include + +namespace raft { +namespace linalg { + +// Reference dot implementation. +template +__global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out) +{ + T sum = 0; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + sum += x[i * incx] * y[i * incy]; + } + atomicAdd(out, sum); +} + +template +struct DotInputs { + OutType tolerance; + IndexType len; + IndexType incx; + IndexType incy; + unsigned long long int seed; +}; + +template +class DotTest : public ::testing::TestWithParam> { + protected: + raft::handle_t handle; + DotInputs params; + rmm::device_scalar output; + rmm::device_scalar refoutput; + + public: + DotTest() + : testing::TestWithParam>(), + output(0, handle.get_stream()), + refoutput(0, handle.get_stream()) + { + handle.sync_stream(); + } + + protected: + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + + cudaStream_t stream = handle.get_stream(); + + raft::random::RngState r(params.seed); + + int x_len = params.len * params.incx; + int y_len = params.len * params.incy; + + rmm::device_uvector x(x_len, stream); + rmm::device_uvector y(y_len, stream); + uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0)); + uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0)); + + naiveDot<<<256, 256, 0, stream>>>( + params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data()); + + auto out_view = make_device_scalar_view(output.data()); + + if ((params.incx > 1) && (params.incy > 1)) { + dot(handle, + make_strided_device_vector_view(x.data(), params.len, params.incx), + make_strided_device_vector_view(y.data(), params.len, params.incy), + out_view); + } else if (params.incx > 1) { + dot(handle, + make_strided_device_vector_view(x.data(), params.len, params.incx), + make_device_vector_view(y.data(), params.len), + out_view); + } else if (params.incy > 1) { + dot(handle, + make_device_vector_view(x.data(), params.len), + make_strided_device_vector_view(y.data(), params.len, params.incy), + out_view); + } else { + dot(handle, + make_device_vector_view(x.data(), params.len), + make_device_vector_view(y.data(), params.len), + out_view); + } + handle.sync_stream(); + } + + void TearDown() override {} +}; + +const std::vector> inputsf = { + {0.0001f, 1024 * 1024, 1, 1, 1234ULL}, + {0.0001f, 16 * 1024 * 1024, 1, 1, 1234ULL}, + {0.0001f, 98689, 1, 1, 1234ULL}, + {0.0001f, 4 * 1024 * 1024, 1, 1, 1234ULL}, + {0.0001f, 1024 * 1024, 4, 1, 1234ULL}, + {0.0001f, 1024 * 1024, 1, 3, 1234ULL}, + {0.0001f, 1024 * 1024, 4, 3, 1234ULL}, +}; + +const std::vector> inputsd = { + {0.000001f, 1024 * 1024, 1, 1, 1234ULL}, + {0.000001f, 16 * 1024 * 1024, 1, 1, 1234ULL}, + {0.000001f, 98689, 1, 1, 1234ULL}, + {0.000001f, 4 * 1024 * 1024, 1, 1, 1234ULL}, + {0.000001f, 1024 * 1024, 4, 1, 1234ULL}, + {0.000001f, 1024 * 1024, 1, 3, 1234ULL}, + {0.000001f, 1024 * 1024, 4, 3, 1234ULL}, +}; + +typedef DotTest DotTestF; +TEST_P(DotTestF, Result) +{ + ASSERT_TRUE(raft::devArrMatch( + refoutput.data(), output.data(), 1, raft::CompareApprox(params.tolerance))); +} + +typedef DotTest DotTestD; +TEST_P(DotTestD, Result) +{ + ASSERT_TRUE(raft::devArrMatch( + refoutput.data(), output.data(), 1, raft::CompareApprox(params.tolerance))); +} + +INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf)); + +INSTANTIATE_TEST_SUITE_P(DotTests, DotTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace linalg +} // end namespace raft From e6a5bb16927b2c310b84a6671abab9883068fa64 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 31 Oct 2022 16:51:13 -0700 Subject: [PATCH 02/12] formatting --- cpp/include/raft/core/device_mdspan.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index dc2acb13df..da28e48e9b 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -286,9 +286,7 @@ auto make_device_vector_view(ElementType* ptr, IndexType n) * @param[in] stride stride between elements * @return raft::device_vector_view */ -template +template auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride) { vector_extent exts{n}; From f376c513b50b5b7a25dc9593fb725548daaecfd3 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 1 Nov 2022 14:35:14 -0700 Subject: [PATCH 03/12] Updates from code review --- cpp/include/raft/core/device_mdspan.hpp | 33 +++++++++---------------- cpp/include/raft/linalg/dot.cuh | 33 +++++++++---------------- cpp/test/linalg/axpy.cu | 19 ++++++++------ cpp/test/linalg/dot.cu | 23 +++++++++-------- 4 files changed, 47 insertions(+), 61 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index da28e48e9b..8b894b28f3 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -266,32 +266,23 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col * @tparam LayoutPolicy policy for strides and layout ordering * @param[in] ptr on device to wrap * @param[in] n number of elements in pointer + * @param[in] stride the stride between consecutive elements in the vector. Setting to a value + * other than 1 requires LayoutPolicy to be set to layout_stride * @return raft::device_vector_view */ template -auto make_device_vector_view(ElementType* ptr, IndexType n) +auto make_device_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1) { - return device_vector_view{ptr, n}; -} - -/** - * @brief Create a 1-dim mdspan instance for device pointer, using a strided layout - * @tparam ElementType the data type of the vector elements - * @tparam IndexType the index type of the extents - * @tparam LayoutPolicy policy for strides and layout ordering - * @param[in] ptr on device to wrap - * @param[in] n number of elements in pointer - * @param[in] stride stride between elements - * @return raft::device_vector_view - */ -template -auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride) -{ - vector_extent exts{n}; - std::array strides{stride}; - auto layout = typename LayoutPolicy::template mapping>{exts, strides}; - return device_vector_view{ptr, layout}; + if constexpr (std::is_same_v) { + vector_extent exts{n}; + std::array strides{stride}; + auto layout = typename LayoutPolicy::template mapping>{exts, strides}; + return device_vector_view{ptr, layout}; + } else { + RAFT_EXPECTS(stride == 1, "Having a stride != 1 requires a layout_stride LayoutPolicy"); + return device_vector_view{ptr, n}; + } } } // end namespace raft diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index ec3631a042..3acfe9f939 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -25,38 +25,27 @@ #include namespace raft::linalg { - /** * @brief Computes the dot product of two vectors. - * @tparam InputType1 raft::device_mdspan for the first input vector - * @tparam InputType2 raft::device_mdspan for the second input vector - * @tparam OutputType Either a host_scalar_view or device_scalar_view for the output * @param[in] handle raft::handle_t * @param[in] x First input vector * @param[in] y Second input vector - * @param[out] out The output dot product between the x and y vectors + * @param[out] out The output dot product between the x and y vectors. + * @note The out parameter can be either a host_scalar_view or device_scalar_view */ -template , - typename = raft::enable_if_input_device_mdspan, - typename = raft::enable_if_output_mdspan> -void dot(const raft::handle_t& handle, InputType1 x, InputType2 y, OutputType out) +template +void dot(const raft::handle_t& handle, + raft::device_vector_view x, + raft::device_vector_view y, + raft::device_scalar_view out) { RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); - // Right now the inputs and outputs need to all have the same value_type (float/double etc). - // Try to output a meaningful compiler error if mismatched types are passed here. - // Note: In the future we could remove this restriction using the cublasDotEx function - // in the cublas wrapper call, instead of the cublassdot and cublasddot functions. - static_assert(std::is_same_v, - "Both input vectors need to have the same value_type in raft::linalg::dot call"); - static_assert( - std::is_same_v, - "Input vectors and output scalar need to have the same value_type in raft::linalg::dot call"); - RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), x.size(), x.data_handle(), diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index 8ed0302b86..6d932c0c1b 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -41,11 +41,11 @@ struct AxpyInputs { unsigned long long int seed; }; -template +template class AxpyTest : public ::testing::TestWithParam> { protected: raft::handle_t handle; - AxpyInputs params; + AxpyInputs params; rmm::device_uvector refy; rmm::device_uvector y; @@ -67,8 +67,8 @@ class AxpyTest : public ::testing::TestWithParam> { raft::random::RngState r(params.seed); - int x_len = params.len * params.incx; - int y_len = params.len * params.incy; + IndexType x_len = params.len * params.incx; + IndexType y_len = params.len * params.incy; rmm::device_uvector x(x_len, stream); y.resize(y_len, stream); refy.resize(y_len, stream); @@ -89,18 +89,21 @@ class AxpyTest : public ::testing::TestWithParam> { if ((params.incx > 1) && (params.incy > 1)) { axpy(handle, make_host_scalar_view(¶ms.alpha), - make_strided_device_vector_view(x.data(), params.len, params.incx), - make_strided_device_vector_view(y.data(), params.len, params.incy)); + make_device_vector_view( + x.data(), params.len, params.incx), + make_device_vector_view(y.data(), params.len, params.incy)); + } else if (params.incx > 1) { axpy(handle, make_host_scalar_view(¶ms.alpha), - make_strided_device_vector_view(x.data(), params.len, params.incx), + make_device_vector_view( + x.data(), params.len, params.incx), make_device_vector_view(y.data(), params.len)); } else if (params.incy > 1) { axpy(handle, make_host_scalar_view(¶ms.alpha), make_device_vector_view(x.data(), params.len), - make_strided_device_vector_view(y.data(), params.len, params.incy)); + make_device_vector_view(y.data(), params.len, params.incy)); } else { axpy(handle, make_host_scalar_view(¶ms.alpha), diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu index 24c5aa4f7b..67e9ddfac7 100644 --- a/cpp/test/linalg/dot.cu +++ b/cpp/test/linalg/dot.cu @@ -23,7 +23,6 @@ namespace raft { namespace linalg { - // Reference dot implementation. template __global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out) @@ -44,11 +43,11 @@ struct DotInputs { unsigned long long int seed; }; -template +template class DotTest : public ::testing::TestWithParam> { protected: raft::handle_t handle; - DotInputs params; + DotInputs params; rmm::device_scalar output; rmm::device_scalar refoutput; @@ -70,8 +69,8 @@ class DotTest : public ::testing::TestWithParam> { raft::random::RngState r(params.seed); - int x_len = params.len * params.incx; - int y_len = params.len * params.incy; + IndexType x_len = params.len * params.incx; + IndexType y_len = params.len * params.incy; rmm::device_uvector x(x_len, stream); rmm::device_uvector y(y_len, stream); @@ -81,22 +80,26 @@ class DotTest : public ::testing::TestWithParam> { naiveDot<<<256, 256, 0, stream>>>( params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data()); - auto out_view = make_device_scalar_view(output.data()); + auto out_view = make_device_scalar_view(output.data()); if ((params.incx > 1) && (params.incy > 1)) { dot(handle, - make_strided_device_vector_view(x.data(), params.len, params.incx), - make_strided_device_vector_view(y.data(), params.len, params.incy), + make_device_vector_view( + x.data(), params.len, params.incx), + make_device_vector_view( + y.data(), params.len, params.incy), out_view); } else if (params.incx > 1) { dot(handle, - make_strided_device_vector_view(x.data(), params.len, params.incx), + make_device_vector_view( + x.data(), params.len, params.incx), make_device_vector_view(y.data(), params.len), out_view); } else if (params.incy > 1) { dot(handle, make_device_vector_view(x.data(), params.len), - make_strided_device_vector_view(y.data(), params.len, params.incy), + make_device_vector_view( + y.data(), params.len, params.incy), out_view); } else { dot(handle, From 9c9efe884589d9a010f4e9b370ecb3a6e2393932 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 1 Nov 2022 14:50:02 -0700 Subject: [PATCH 04/12] Update axpy to take a device_vector_view --- cpp/include/raft/linalg/axpy.cuh | 69 +++++++++++++++----------------- cpp/include/raft/linalg/dot.cuh | 8 ++-- 2 files changed, 36 insertions(+), 41 deletions(-) diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index d3abfed6e6..8e34ab3f33 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -62,66 +62,61 @@ void axpy(const raft::handle_t& handle, * @brief axpy function * It computes the following equation: y = alpha * x + y * - * @tparam InType Type raft::device_mdspan - * @tparam ScalarIdxType Index Type of scalar * @param [in] handle raft::handle_t * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector */ -template , - typename = raft::enable_if_output_device_mdspan> +template void axpy(const raft::handle_t& handle, - raft::device_scalar_view alpha, - InType x, - OutType y) + raft::device_scalar_view alpha, + raft::device_vector_view x, + raft::device_vector_view y) { RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); - axpy(handle, - y.size(), - alpha.data_handle(), - x.data_handle(), - x.stride(0), - y.data_handle(), - y.stride(0), - handle.get_stream()); + axpy(handle, + y.size(), + alpha.data_handle(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + handle.get_stream()); } /** * @brief axpy function * It computes the following equation: y = alpha * x + y - * - * @tparam MdspanType Type raft::device_mdspan - * @tparam ScalarIdxType Index Type of scalar * @param [in] handle raft::handle_t * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector */ -template , - typename = raft::enable_if_output_device_mdspan> +template void axpy(const raft::handle_t& handle, - raft::host_scalar_view alpha, - InType x, - OutType y) + raft::host_scalar_view alpha, + raft::device_vector_view x, + raft::device_vector_view y) { RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); - axpy(handle, - y.size(), - alpha.data_handle(), - x.data_handle(), - x.stride(0), - y.data_handle(), - y.stride(0), - handle.get_stream()); + axpy(handle, + y.size(), + alpha.data_handle(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + handle.get_stream()); } /** @} */ // end of group axpy diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index 3acfe9f939..783929c175 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -33,15 +33,15 @@ namespace raft::linalg { * @param[out] out The output dot product between the x and y vectors. * @note The out parameter can be either a host_scalar_view or device_scalar_view */ -template void dot(const raft::handle_t& handle, - raft::device_vector_view x, - raft::device_vector_view y, - raft::device_scalar_view out) + raft::device_vector_view x, + raft::device_vector_view y, + raft::device_scalar_view out) { RAFT_EXPECTS(x.size() == y.size(), "Size mismatch between x and y input vectors in raft::linalg::dot"); From 2bf4eaebb9caeff8825dd81b83ce53b9db19ae0e Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 8 Nov 2022 09:48:28 -0800 Subject: [PATCH 05/12] Changes from codereview * Remove default types, * Try to fix up factory functions for creating strided vector views * Add dot funcction that takes host scalar / host_scalar_view --- cpp/include/raft/core/device_mdspan.hpp | 49 ++++++++++++------ cpp/include/raft/linalg/axpy.cuh | 16 +++--- cpp/include/raft/linalg/dot.cuh | 69 +++++++++++++++++++++++-- cpp/test/linalg/axpy.cu | 14 ++--- cpp/test/linalg/dot.cu | 8 +-- 5 files changed, 117 insertions(+), 39 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 8b894b28f3..470ade3d6e 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -266,23 +266,42 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col * @tparam LayoutPolicy policy for strides and layout ordering * @param[in] ptr on device to wrap * @param[in] n number of elements in pointer - * @param[in] stride the stride between consecutive elements in the vector. Setting to a value - * other than 1 requires LayoutPolicy to be set to layout_stride * @return raft::device_vector_view */ -template -auto make_device_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1) +template +auto make_device_vector_view(ElementType* ptr, IndexType n) +{ + return device_vector_view{ptr, n}; +} + +/** + * @brief Create a 1-dim mdspan instance for device pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] mapping The layout mapping to use for this vector + * @return raft::device_vector_view + */ +template +auto make_device_vector_view( + ElementType* ptr, + const typename LayoutPolicy::template mapping>& mapping) +{ + return device_vector_view{ptr, mapping}; +} + +/** + * @brief Construct a strided vector layout mapping + * @tparam IndexType the index type of the extents + * @params[in] n the number of elements in the vector + * @params[in] stride the stride between elements in the vector + */ +template +auto make_vector_strided_layout(IndexType n, IndexType stride) { - if constexpr (std::is_same_v) { - vector_extent exts{n}; - std::array strides{stride}; - auto layout = typename LayoutPolicy::template mapping>{exts, strides}; - return device_vector_view{ptr, layout}; - } else { - RAFT_EXPECTS(stride == 1, "Having a stride != 1 requires a layout_stride LayoutPolicy"); - return device_vector_view{ptr, n}; - } + vector_extent exts{n}; + std::array strides{stride}; + return layout_stride::mapping>{exts, strides}; } } // end namespace raft diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 8e34ab3f33..88b065c8b0 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -68,10 +68,10 @@ void axpy(const raft::handle_t& handle, * @param [inout] y Output vector */ template + typename IndexType, + typename InLayoutPolicy, + typename OutLayoutPolicy, + typename ScalarIdxType> void axpy(const raft::handle_t& handle, raft::device_scalar_view alpha, raft::device_vector_view x, @@ -98,10 +98,10 @@ void axpy(const raft::handle_t& handle, * @param [inout] y Output vector */ template + typename IndexType, + typename InLayoutPolicy, + typename OutLayoutPolicy, + typename ScalarIdxType> void axpy(const raft::handle_t& handle, raft::host_scalar_view alpha, raft::device_vector_view x, diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index 783929c175..2f24f4856e 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -31,13 +31,12 @@ namespace raft::linalg { * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. - * @note The out parameter can be either a host_scalar_view or device_scalar_view */ template + typename IndexType, + typename ScalarIndexType, + typename LayoutPolicy1, + typename LayoutPolicy2> void dot(const raft::handle_t& handle, raft::device_vector_view x, raft::device_vector_view y, @@ -55,5 +54,65 @@ void dot(const raft::handle_t& handle, out.data_handle(), handle.get_stream())); } + +/** + * @brief Computes the dot product of two vectors. + * @param[in] handle raft::handle_t + * @param[in] x First input vector + * @param[in] y Second input vector + * @param[out] out The output dot product between the x and y vectors. + */ +template +void dot(const raft::handle_t& handle, + raft::device_vector_view x, + raft::device_vector_view y, + raft::host_scalar_view out) +{ + RAFT_EXPECTS(x.size() == y.size(), + "Size mismatch between x and y input vectors in raft::linalg::dot"); + + RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + x.size(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + out.data_handle(), + handle.get_stream())); +} + +/** + * @brief Computes the dot product of two vectors. + * @param[in] handle raft::handle_t + * @param[in] x First input vector + * @param[in] y Second input vector + * @param[out] out The output dot product between the x and y vectors. + */ +template +void dot(const raft::handle_t& handle, + raft::device_vector_view x, + raft::device_vector_view y, + ElementType* out) +{ + RAFT_EXPECTS(x.size() == y.size(), + "Size mismatch between x and y input vectors in raft::linalg::dot"); + + RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + x.size(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + out, + handle.get_stream())); +} } // namespace raft::linalg #endif diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index 6d932c0c1b..fb07e2ec2f 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -22,7 +22,6 @@ namespace raft { namespace linalg { - // Reference axpy implementation. template __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx, int incy) @@ -90,20 +89,21 @@ class AxpyTest : public ::testing::TestWithParam> { axpy(handle, make_host_scalar_view(¶ms.alpha), make_device_vector_view( - x.data(), params.len, params.incx), - make_device_vector_view(y.data(), params.len, params.incy)); - + x.data(), make_vector_strided_layout(params.len, params.incx)), + make_device_vector_view( + y.data(), make_vector_strided_layout(params.len, params.incy))); } else if (params.incx > 1) { axpy(handle, make_host_scalar_view(¶ms.alpha), make_device_vector_view( - x.data(), params.len, params.incx), - make_device_vector_view(y.data(), params.len)); + x.data(), make_vector_strided_layout(params.len, params.incx)), + make_device_vector_view(y.data(), params.len)); } else if (params.incy > 1) { axpy(handle, make_host_scalar_view(¶ms.alpha), make_device_vector_view(x.data(), params.len), - make_device_vector_view(y.data(), params.len, params.incy)); + make_device_vector_view( + y.data(), make_vector_strided_layout(params.len, params.incy))); } else { axpy(handle, make_host_scalar_view(¶ms.alpha), diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu index 67e9ddfac7..f0bce47189 100644 --- a/cpp/test/linalg/dot.cu +++ b/cpp/test/linalg/dot.cu @@ -85,21 +85,21 @@ class DotTest : public ::testing::TestWithParam> { if ((params.incx > 1) && (params.incy > 1)) { dot(handle, make_device_vector_view( - x.data(), params.len, params.incx), + x.data(), make_vector_strided_layout(params.len, params.incx)), make_device_vector_view( - y.data(), params.len, params.incy), + y.data(), make_vector_strided_layout(params.len, params.incy)), out_view); } else if (params.incx > 1) { dot(handle, make_device_vector_view( - x.data(), params.len, params.incx), + x.data(), make_vector_strided_layout(params.len, params.incx)), make_device_vector_view(y.data(), params.len), out_view); } else if (params.incy > 1) { dot(handle, make_device_vector_view(x.data(), params.len), make_device_vector_view( - y.data(), params.len, params.incy), + y.data(), make_vector_strided_layout(params.len, params.incy)), out_view); } else { dot(handle, From 2e1c0e65bc7309ad52e7a6f1cc5dceebd2ae05b1 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 8 Nov 2022 11:07:02 -0800 Subject: [PATCH 06/12] remove dot w/ host pointer overload --- cpp/include/raft/linalg/dot.cuh | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index 2f24f4856e..48577650bc 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -84,35 +84,5 @@ void dot(const raft::handle_t& handle, out.data_handle(), handle.get_stream())); } - -/** - * @brief Computes the dot product of two vectors. - * @param[in] handle raft::handle_t - * @param[in] x First input vector - * @param[in] y Second input vector - * @param[out] out The output dot product between the x and y vectors. - */ -template -void dot(const raft::handle_t& handle, - raft::device_vector_view x, - raft::device_vector_view y, - ElementType* out) -{ - RAFT_EXPECTS(x.size() == y.size(), - "Size mismatch between x and y input vectors in raft::linalg::dot"); - - RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), - x.size(), - x.data_handle(), - x.stride(0), - y.data_handle(), - y.stride(0), - out, - handle.get_stream())); -} } // namespace raft::linalg #endif From a66809886a0656f9266ca8fd5640d27be87a8a6a Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 8 Nov 2022 15:46:03 -0800 Subject: [PATCH 07/12] Added docs / created 'make_strided_layout' factory function --- cpp/include/raft/core/device_mdspan.hpp | 21 ++++++++++++++++++--- docs/source/quick_start.md | 12 ++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 470ade3d6e..845e8740e0 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -291,8 +291,25 @@ auto make_device_vector_view( return device_vector_view{ptr, mapping}; } +template +auto make_strided_layout(ExtentType extents, StrideType strides) +{ + return layout_stride::mapping{extents, strides}; +} + /** * @brief Construct a strided vector layout mapping + * + * Usage example: + * @code{.cpp} + * #include + * + * int n_elements = 10; + * int stride = 10; + * auto vector = raft::make_device_vector_view(vector_ptr, + * raft::make_vector_strided_layout(n_elements, stride)); + * @endcode + * * @tparam IndexType the index type of the extents * @params[in] n the number of elements in the vector * @params[in] stride the stride between elements in the vector @@ -300,8 +317,6 @@ auto make_device_vector_view( template auto make_vector_strided_layout(IndexType n, IndexType stride) { - vector_extent exts{n}; - std::array strides{stride}; - return layout_stride::mapping>{exts, strides}; + return make_strided_layout(vector_extent{n}, std::array{stride}); } } // end namespace raft diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index e73f9b8a7a..38ed031759 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -77,6 +77,18 @@ int n_cols = 10; auto matrix = raft::make_managed_mdspan(managed_ptr, raft::make_matrix_extents(n_rows, n_cols)); ``` +You can also create strided mdspans: + +```c++ + +#include + +int n_elements = 10; +int stride = 10; + +auto vector = raft::make_device_vector_view(vector_ptr, raft::make_vector_strided_layout(n_elements, stride)); +``` + ## C++ Example From 977949f6412b73c890fc6bc5e54c0c7072688136 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 8 Nov 2022 16:13:12 -0800 Subject: [PATCH 08/12] Add doxygen for make_strided_layout --- cpp/include/raft/core/device_mdspan.hpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 845e8740e0..37321dba7b 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -291,8 +291,14 @@ auto make_device_vector_view( return device_vector_view{ptr, mapping}; } -template -auto make_strided_layout(ExtentType extents, StrideType strides) +/** + * @brief Create a layout_stride mapping from extents and strides + * @param[in] The Extents + * @param[in] mapping The layout mapping to use for this vector + * @return raft::layout_stride::mapping + */ +template +auto make_strided_layout(Extents extents, Strides strides) { return layout_stride::mapping{extents, strides}; } @@ -311,8 +317,8 @@ auto make_strided_layout(ExtentType extents, StrideType strides) * @endcode * * @tparam IndexType the index type of the extents - * @params[in] n the number of elements in the vector - * @params[in] stride the stride between elements in the vector + * @param[in] n the number of elements in the vector + * @param[in] stride the stride between elements in the vector */ template auto make_vector_strided_layout(IndexType n, IndexType stride) From a125c4f7ea5b90cdd7b0e81ac7c980553f0a03ce Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 9 Nov 2022 10:29:36 -0800 Subject: [PATCH 09/12] fix --- cpp/include/raft/core/device_mdspan.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index 37321dba7b..17dd8e94d0 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -300,7 +300,7 @@ auto make_device_vector_view( template auto make_strided_layout(Extents extents, Strides strides) { - return layout_stride::mapping{extents, strides}; + return layout_stride::mapping{extents, strides}; } /** From 6f8a76c7c2bbe07a7cd697dfd59c87079273da2d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 9 Nov 2022 11:04:19 -0800 Subject: [PATCH 10/12] Test out host and device api's --- cpp/test/linalg/dot.cu | 73 +++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu index f0bce47189..b5007aea32 100644 --- a/cpp/test/linalg/dot.cu +++ b/cpp/test/linalg/dot.cu @@ -46,25 +46,18 @@ struct DotInputs { template class DotTest : public ::testing::TestWithParam> { protected: - raft::handle_t handle; DotInputs params; - rmm::device_scalar output; - rmm::device_scalar refoutput; + T host_output, device_output, ref_output; public: - DotTest() - : testing::TestWithParam>(), - output(0, handle.get_stream()), - refoutput(0, handle.get_stream()) - { - handle.sync_stream(); - } + DotTest() : testing::TestWithParam>() {} protected: void SetUp() override { params = ::testing::TestWithParam>::GetParam(); + raft::handle_t handle; cudaStream_t stream = handle.get_stream(); raft::random::RngState r(params.seed); @@ -77,36 +70,42 @@ class DotTest : public ::testing::TestWithParam> { uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0)); uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0)); + rmm::device_scalar ref(0, handle.get_stream()); naiveDot<<<256, 256, 0, stream>>>( - params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data()); + params.len, x.data(), params.incx, y.data(), params.incy, ref.data()); + raft::update_host(&ref_output, ref.data(), 1, stream); - auto out_view = make_device_scalar_view(output.data()); + // Test out both the device and host api's + rmm::device_scalar out(0, handle.get_stream()); + auto device_out_view = make_device_scalar_view(out.data()); + auto host_out_view = make_host_scalar_view(&host_output); if ((params.incx > 1) && (params.incy > 1)) { - dot(handle, - make_device_vector_view( - x.data(), make_vector_strided_layout(params.len, params.incx)), - make_device_vector_view( - y.data(), make_vector_strided_layout(params.len, params.incy)), - out_view); + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + auto y_view = make_device_vector_view( + y.data(), make_vector_strided_layout(params.len, params.incy)); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); } else if (params.incx > 1) { - dot(handle, - make_device_vector_view( - x.data(), make_vector_strided_layout(params.len, params.incx)), - make_device_vector_view(y.data(), params.len), - out_view); + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + auto y_view = make_device_vector_view(y.data(), params.len); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); } else if (params.incy > 1) { - dot(handle, - make_device_vector_view(x.data(), params.len), - make_device_vector_view( - y.data(), make_vector_strided_layout(params.len, params.incy)), - out_view); + auto x_view = make_device_vector_view(x.data(), params.len); + auto y_view = make_device_vector_view( + y.data(), make_vector_strided_layout(params.len, params.incy)); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); } else { - dot(handle, - make_device_vector_view(x.data(), params.len), - make_device_vector_view(y.data(), params.len), - out_view); + auto x_view = make_device_vector_view(x.data(), params.len); + auto y_view = make_device_vector_view(y.data(), params.len); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); } + raft::update_host(&device_output, out.data(), 1, stream); handle.sync_stream(); } @@ -136,15 +135,17 @@ const std::vector> inputsd = { typedef DotTest DotTestF; TEST_P(DotTestF, Result) { - ASSERT_TRUE(raft::devArrMatch( - refoutput.data(), output.data(), 1, raft::CompareApprox(params.tolerance))); + auto compare = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(compare(ref_output, host_output)); + ASSERT_TRUE(compare(ref_output, device_output)); } typedef DotTest DotTestD; TEST_P(DotTestD, Result) { - ASSERT_TRUE(raft::devArrMatch( - refoutput.data(), output.data(), 1, raft::CompareApprox(params.tolerance))); + auto compare = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(compare(ref_output, host_output)); + ASSERT_TRUE(compare(ref_output, device_output)); } INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf)); From 25333eee56da9582115984971b0e2d6a3c8fc962 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 9 Nov 2022 11:43:12 -0800 Subject: [PATCH 11/12] test out both device/host alpha scalar overloads with dot --- cpp/test/linalg/axpy.cu | 90 ++++++++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 23 deletions(-) diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index fb07e2ec2f..f6cabae012 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -20,6 +20,8 @@ #include #include +#include + namespace raft { namespace linalg { // Reference axpy implementation. @@ -46,13 +48,15 @@ class AxpyTest : public ::testing::TestWithParam> { raft::handle_t handle; AxpyInputs params; rmm::device_uvector refy; - rmm::device_uvector y; + rmm::device_uvector y_device_alpha; + rmm::device_uvector y_host_alpha; public: AxpyTest() : testing::TestWithParam>(), refy(0, handle.get_stream()), - y(0, handle.get_stream()) + y_host_alpha(0, handle.get_stream()), + y_device_alpha(0, handle.get_stream()) { handle.sync_stream(); } @@ -69,15 +73,17 @@ class AxpyTest : public ::testing::TestWithParam> { IndexType x_len = params.len * params.incx; IndexType y_len = params.len * params.incy; rmm::device_uvector x(x_len, stream); - y.resize(y_len, stream); + y_host_alpha.resize(y_len, stream); + y_device_alpha.resize(y_len, stream); refy.resize(y_len, stream); uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0)); - uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0)); + uniform(handle, r, refy.data(), y_len, T(-1.0), T(1.0)); - // Take a copy of the random generated values in y for the naive reference implementation + // Take a copy of the random generated values in refy // this is necessary since axpy uses y for both input and output - raft::copy(refy.data(), y.data(), y_len, stream); + raft::copy(y_host_alpha.data(), refy.data(), y_len, stream); + raft::copy(y_device_alpha.data(), refy.data(), y_len, stream); int threads = 64; int blocks = raft::ceildiv(params.len, threads); @@ -85,30 +91,58 @@ class AxpyTest : public ::testing::TestWithParam> { naiveAxpy<<>>( params.len, params.alpha, x.data(), refy.data(), params.incx, params.incy); + auto host_alpha_view = make_host_scalar_view(¶ms.alpha); + + // test out both axpy overloads - taking either a host scalar or device scalar view + rmm::device_scalar device_alpha(params.alpha, stream); + auto device_alpha_view = make_device_scalar_view(device_alpha.data()); + if ((params.incx > 1) && (params.incy > 1)) { + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + axpy(handle, + host_alpha_view, + x_view, + make_device_vector_view( + y_host_alpha.data(), make_vector_strided_layout(params.len, params.incy))); axpy(handle, - make_host_scalar_view(¶ms.alpha), - make_device_vector_view( - x.data(), make_vector_strided_layout(params.len, params.incx)), + device_alpha_view, + x_view, make_device_vector_view( - y.data(), make_vector_strided_layout(params.len, params.incy))); + y_device_alpha.data(), make_vector_strided_layout(params.len, params.incy))); } else if (params.incx > 1) { + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + axpy(handle, + host_alpha_view, + x_view, + make_device_vector_view(y_host_alpha.data(), params.len)); axpy(handle, - make_host_scalar_view(¶ms.alpha), - make_device_vector_view( - x.data(), make_vector_strided_layout(params.len, params.incx)), - make_device_vector_view(y.data(), params.len)); + device_alpha_view, + x_view, + make_device_vector_view(y_device_alpha.data(), params.len)); } else if (params.incy > 1) { + auto x_view = make_device_vector_view(x.data(), params.len); + axpy(handle, + host_alpha_view, + x_view, + make_device_vector_view( + y_host_alpha.data(), make_vector_strided_layout(params.len, params.incy))); axpy(handle, - make_host_scalar_view(¶ms.alpha), - make_device_vector_view(x.data(), params.len), + device_alpha_view, + x_view, make_device_vector_view( - y.data(), make_vector_strided_layout(params.len, params.incy))); + y_device_alpha.data(), make_vector_strided_layout(params.len, params.incy))); } else { + auto x_view = make_device_vector_view(x.data(), params.len); axpy(handle, - make_host_scalar_view(¶ms.alpha), - make_device_vector_view(x.data(), params.len), - make_device_vector_view(y.data(), params.len)); + host_alpha_view, + x_view, + make_device_vector_view(y_host_alpha.data(), params.len)); + axpy(handle, + device_alpha_view, + x_view, + make_device_vector_view(y_device_alpha.data(), params.len)); } handle.sync_stream(); @@ -140,15 +174,25 @@ const std::vector> inputsd = { typedef AxpyTest AxpyTestF; TEST_P(AxpyTestF, Result) { - ASSERT_TRUE(raft::devArrMatch( - refy.data(), y.data(), params.len * params.incy, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(refy.data(), + y_host_alpha.data(), + params.len * params.incy, + raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(refy.data(), + y_device_alpha.data(), + params.len * params.incy, + raft::CompareApprox(params.tolerance))); } typedef AxpyTest AxpyTestD; TEST_P(AxpyTestD, Result) { ASSERT_TRUE(raft::devArrMatch(refy.data(), - y.data(), + y_host_alpha.data(), + params.len * params.incy, + raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(refy.data(), + y_device_alpha.data(), params.len * params.incy, raft::CompareApprox(params.tolerance))); } From 98f7d85565c85fb8ccd21f8bd283bd8a0518c5c6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 9 Nov 2022 21:12:40 -0800 Subject: [PATCH 12/12] Fix docstring --- cpp/include/raft/core/device_mdspan.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index b87ce6e44c..3386610224 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -294,8 +294,8 @@ auto make_device_vector_view( /** * @brief Create a layout_stride mapping from extents and strides - * @param[in] The Extents - * @param[in] mapping The layout mapping to use for this vector + * @param[in] extents the dimensionality of the layout + * @param[in] strides the strides between elements in the layout * @return raft::layout_stride::mapping */ template