diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index ae66f315d9..3386610224 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -269,12 +269,61 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col * @param[in] n number of elements in pointer * @return raft::device_vector_view */ -template +template auto make_device_vector_view(ElementType* ptr, IndexType n) { return device_vector_view{ptr, n}; } +/** + * @brief Create a 1-dim mdspan instance for device pointer. + * @tparam ElementType the data type of the vector elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] mapping The layout mapping to use for this vector + * @return raft::device_vector_view + */ +template +auto make_device_vector_view( + ElementType* ptr, + const typename LayoutPolicy::template mapping>& mapping) +{ + return device_vector_view{ptr, mapping}; +} + +/** + * @brief Create a layout_stride mapping from extents and strides + * @param[in] extents the dimensionality of the layout + * @param[in] strides the strides between elements in the layout + * @return raft::layout_stride::mapping + */ +template +auto make_strided_layout(Extents extents, Strides strides) +{ + return layout_stride::mapping{extents, strides}; +} + +/** + * @brief Construct a strided vector layout mapping + * + * Usage example: + * @code{.cpp} + * #include + * + * int n_elements = 10; + * int stride = 10; + * auto vector = raft::make_device_vector_view(vector_ptr, + * raft::make_vector_strided_layout(n_elements, stride)); + * @endcode + * + * @tparam IndexType the index type of the extents + * @param[in] n the number of elements in the vector + * @param[in] stride the stride between elements in the vector + */ +template +auto make_vector_strided_layout(IndexType n, IndexType 
stride) +{ + return make_strided_layout(vector_extent{n}, std::array{stride}); +} } // end namespace raft diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index d3abfed6e6..88b065c8b0 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -62,66 +62,61 @@ void axpy(const raft::handle_t& handle, * @brief axpy function * It computes the following equation: y = alpha * x + y * - * @tparam InType Type raft::device_mdspan - * @tparam ScalarIdxType Index Type of scalar * @param [in] handle raft::handle_t * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector */ -template , - typename = raft::enable_if_output_device_mdspan> +template void axpy(const raft::handle_t& handle, - raft::device_scalar_view alpha, - InType x, - OutType y) + raft::device_scalar_view alpha, + raft::device_vector_view x, + raft::device_vector_view y) { RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); - axpy(handle, - y.size(), - alpha.data_handle(), - x.data_handle(), - x.stride(0), - y.data_handle(), - y.stride(0), - handle.get_stream()); + axpy(handle, + y.size(), + alpha.data_handle(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + handle.get_stream()); } /** * @brief axpy function * It computes the following equation: y = alpha * x + y - * - * @tparam MdspanType Type raft::device_mdspan - * @tparam ScalarIdxType Index Type of scalar * @param [in] handle raft::handle_t * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector */ -template , - typename = raft::enable_if_output_device_mdspan> +template void axpy(const raft::handle_t& handle, - raft::host_scalar_view alpha, - InType x, - OutType y) + raft::host_scalar_view alpha, + raft::device_vector_view x, + raft::device_vector_view y) { RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input"); - axpy(handle, - 
y.size(), - alpha.data_handle(), - x.data_handle(), - x.stride(0), - y.data_handle(), - y.stride(0), - handle.get_stream()); + axpy(handle, + y.size(), + alpha.data_handle(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + handle.get_stream()); } /** @} */ // end of group axpy diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh new file mode 100644 index 0000000000..48577650bc --- /dev/null +++ b/cpp/include/raft/linalg/dot.cuh @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef __DOT_H +#define __DOT_H + +#pragma once + +#include + +#include +#include +#include + +namespace raft::linalg { +/** + * @brief Computes the dot product of two vectors. + * @param[in] handle raft::handle_t + * @param[in] x First input vector + * @param[in] y Second input vector + * @param[out] out The output dot product between the x and y vectors. + */ +template +void dot(const raft::handle_t& handle, + raft::device_vector_view x, + raft::device_vector_view y, + raft::device_scalar_view out) +{ + RAFT_EXPECTS(x.size() == y.size(), + "Size mismatch between x and y input vectors in raft::linalg::dot"); + + RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + x.size(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + out.data_handle(), + handle.get_stream())); +} + +/** + * @brief Computes the dot product of two vectors. 
+ * @param[in] handle raft::handle_t + * @param[in] x First input vector + * @param[in] y Second input vector + * @param[out] out The output dot product between the x and y vectors. + */ +template +void dot(const raft::handle_t& handle, + raft::device_vector_view x, + raft::device_vector_view y, + raft::host_scalar_view out) +{ + RAFT_EXPECTS(x.size() == y.size(), + "Size mismatch between x and y input vectors in raft::linalg::dot"); + + RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(), + x.size(), + x.data_handle(), + x.stride(0), + y.data_handle(), + y.stride(0), + out.data_handle(), + handle.get_stream())); +} +} // namespace raft::linalg +#endif diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 792bcf1ec1..0f5ebabcb9 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -139,6 +139,7 @@ if(BUILD_TESTS) test/linalg/cholesky_r1.cu test/linalg/coalesced_reduction.cu test/linalg/divide.cu + test/linalg/dot.cu test/linalg/eig.cu test/linalg/eig_sel.cu test/linalg/gemm_layout.cu diff --git a/cpp/test/linalg/axpy.cu b/cpp/test/linalg/axpy.cu index fdab845914..f6cabae012 100644 --- a/cpp/test/linalg/axpy.cu +++ b/cpp/test/linalg/axpy.cu @@ -20,9 +20,10 @@ #include #include +#include + namespace raft { namespace linalg { - // Reference axpy implementation. 
template __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx, int incy) @@ -31,17 +32,6 @@ __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx if (idx < n) { y[idx * incy] += alpha * x[idx * incx]; } } -template -auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride) -{ - vector_extent exts{n}; - std::array strides{stride}; - auto layout = typename LayoutPolicy::mapping>{exts, strides}; - return device_vector_view{ptr, layout}; -} - template struct AxpyInputs { OutType tolerance; @@ -52,19 +42,21 @@ struct AxpyInputs { unsigned long long int seed; }; -template +template class AxpyTest : public ::testing::TestWithParam> { protected: raft::handle_t handle; - AxpyInputs params; + AxpyInputs params; rmm::device_uvector refy; - rmm::device_uvector y; + rmm::device_uvector y_device_alpha; + rmm::device_uvector y_host_alpha; public: AxpyTest() : testing::TestWithParam>(), refy(0, handle.get_stream()), - y(0, handle.get_stream()) + y_host_alpha(0, handle.get_stream()), + y_device_alpha(0, handle.get_stream()) { handle.sync_stream(); } @@ -78,18 +70,20 @@ class AxpyTest : public ::testing::TestWithParam> { raft::random::RngState r(params.seed); - int x_len = params.len * params.incx; - int y_len = params.len * params.incy; + IndexType x_len = params.len * params.incx; + IndexType y_len = params.len * params.incy; rmm::device_uvector x(x_len, stream); - y.resize(y_len, stream); + y_host_alpha.resize(y_len, stream); + y_device_alpha.resize(y_len, stream); refy.resize(y_len, stream); uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0)); - uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0)); + uniform(handle, r, refy.data(), y_len, T(-1.0), T(1.0)); - // Take a copy of the random generated values in y for the naive reference implementation + // Take a copy of the random generated values in refy // this is necessary since axpy uses y for both input and output - raft::copy(refy.data(), 
y.data(), y_len, stream); + raft::copy(y_host_alpha.data(), refy.data(), y_len, stream); + raft::copy(y_device_alpha.data(), refy.data(), y_len, stream); int threads = 64; int blocks = raft::ceildiv(params.len, threads); @@ -97,26 +91,58 @@ class AxpyTest : public ::testing::TestWithParam> { naiveAxpy<<>>( params.len, params.alpha, x.data(), refy.data(), params.incx, params.incy); + auto host_alpha_view = make_host_scalar_view(&params.alpha); + + // test out both axpy overloads - taking either a host scalar or device scalar view + rmm::device_scalar device_alpha(params.alpha, stream); + auto device_alpha_view = make_device_scalar_view(device_alpha.data()); + if ((params.incx > 1) && (params.incy > 1)) { + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); axpy(handle, - make_host_scalar_view(&params.alpha), - make_strided_device_vector_view(x.data(), params.len, params.incx), - make_strided_device_vector_view(y.data(), params.len, params.incy)); + host_alpha_view, + x_view, + make_device_vector_view( + y_host_alpha.data(), make_vector_strided_layout(params.len, params.incy))); + axpy(handle, + device_alpha_view, + x_view, + make_device_vector_view( + y_device_alpha.data(), make_vector_strided_layout(params.len, params.incy))); } else if (params.incx > 1) { + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + axpy(handle, + host_alpha_view, + x_view, + make_device_vector_view(y_host_alpha.data(), params.len)); axpy(handle, - make_host_scalar_view(&params.alpha), - make_strided_device_vector_view(x.data(), params.len, params.incx), - make_device_vector_view(y.data(), params.len)); + device_alpha_view, + x_view, + make_device_vector_view(y_device_alpha.data(), params.len)); } else if (params.incy > 1) { + auto x_view = make_device_vector_view(x.data(), params.len); axpy(handle, - make_host_scalar_view(&params.alpha), - 
make_strided_device_vector_view(y.data(), params.len, params.incy)); + host_alpha_view, + x_view, + make_device_vector_view( + y_host_alpha.data(), make_vector_strided_layout(params.len, params.incy))); + axpy(handle, + device_alpha_view, + x_view, + make_device_vector_view( + y_device_alpha.data(), make_vector_strided_layout(params.len, params.incy))); } else { + auto x_view = make_device_vector_view(x.data(), params.len); + axpy(handle, + host_alpha_view, + x_view, + make_device_vector_view(y_host_alpha.data(), params.len)); axpy(handle, - make_host_scalar_view(&params.alpha), - make_device_vector_view(x.data(), params.len), - make_device_vector_view(y.data(), params.len)); + device_alpha_view, + x_view, + make_device_vector_view(y_device_alpha.data(), params.len)); } handle.sync_stream(); @@ -148,15 +174,25 @@ const std::vector> inputsd = { typedef AxpyTest AxpyTestF; TEST_P(AxpyTestF, Result) { - ASSERT_TRUE(raft::devArrMatch( - refy.data(), y.data(), params.len * params.incy, raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(refy.data(), + y_host_alpha.data(), + params.len * params.incy, + raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(refy.data(), + y_device_alpha.data(), + params.len * params.incy, + raft::CompareApprox(params.tolerance))); } typedef AxpyTest AxpyTestD; TEST_P(AxpyTestD, Result) { ASSERT_TRUE(raft::devArrMatch(refy.data(), - y.data(), + y_host_alpha.data(), + params.len * params.incy, + raft::CompareApprox(params.tolerance))); + ASSERT_TRUE(raft::devArrMatch(refy.data(), + y_device_alpha.data(), params.len * params.incy, raft::CompareApprox(params.tolerance))); } diff --git a/cpp/test/linalg/dot.cu b/cpp/test/linalg/dot.cu new file mode 100644 index 0000000000..b5007aea32 --- /dev/null +++ b/cpp/test/linalg/dot.cu @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "../test_utils.h" +#include +#include +#include +#include + +namespace raft { +namespace linalg { +// Reference dot implementation. +template +__global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out) +{ + T sum = 0; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + sum += x[i * incx] * y[i * incy]; + } + atomicAdd(out, sum); +} + +template +struct DotInputs { + OutType tolerance; + IndexType len; + IndexType incx; + IndexType incy; + unsigned long long int seed; +}; + +template +class DotTest : public ::testing::TestWithParam> { + protected: + DotInputs params; + T host_output, device_output, ref_output; + + public: + DotTest() : testing::TestWithParam>() {} + + protected: + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + + raft::handle_t handle; + cudaStream_t stream = handle.get_stream(); + + raft::random::RngState r(params.seed); + + IndexType x_len = params.len * params.incx; + IndexType y_len = params.len * params.incy; + + rmm::device_uvector x(x_len, stream); + rmm::device_uvector y(y_len, stream); + uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0)); + uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0)); + + rmm::device_scalar ref(0, handle.get_stream()); + naiveDot<<<256, 256, 0, stream>>>( + params.len, x.data(), params.incx, y.data(), params.incy, ref.data()); + 
raft::update_host(&ref_output, ref.data(), 1, stream); + + // Test out both the device and host api's + rmm::device_scalar out(0, handle.get_stream()); + auto device_out_view = make_device_scalar_view(out.data()); + auto host_out_view = make_host_scalar_view(&host_output); + + if ((params.incx > 1) && (params.incy > 1)) { + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + auto y_view = make_device_vector_view( + y.data(), make_vector_strided_layout(params.len, params.incy)); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); + } else if (params.incx > 1) { + auto x_view = make_device_vector_view( + x.data(), make_vector_strided_layout(params.len, params.incx)); + auto y_view = make_device_vector_view(y.data(), params.len); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); + } else if (params.incy > 1) { + auto x_view = make_device_vector_view(x.data(), params.len); + auto y_view = make_device_vector_view( + y.data(), make_vector_strided_layout(params.len, params.incy)); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); + } else { + auto x_view = make_device_vector_view(x.data(), params.len); + auto y_view = make_device_vector_view(y.data(), params.len); + dot(handle, x_view, y_view, device_out_view); + dot(handle, x_view, y_view, host_out_view); + } + raft::update_host(&device_output, out.data(), 1, stream); + handle.sync_stream(); + } + + void TearDown() override {} +}; + +const std::vector> inputsf = { + {0.0001f, 1024 * 1024, 1, 1, 1234ULL}, + {0.0001f, 16 * 1024 * 1024, 1, 1, 1234ULL}, + {0.0001f, 98689, 1, 1, 1234ULL}, + {0.0001f, 4 * 1024 * 1024, 1, 1, 1234ULL}, + {0.0001f, 1024 * 1024, 4, 1, 1234ULL}, + {0.0001f, 1024 * 1024, 1, 3, 1234ULL}, + {0.0001f, 1024 * 1024, 4, 3, 1234ULL}, +}; + +const std::vector> inputsd = { + {0.000001f, 1024 * 1024, 1, 1, 1234ULL}, + 
{0.000001f, 16 * 1024 * 1024, 1, 1, 1234ULL}, + {0.000001f, 98689, 1, 1, 1234ULL}, + {0.000001f, 4 * 1024 * 1024, 1, 1, 1234ULL}, + {0.000001f, 1024 * 1024, 4, 1, 1234ULL}, + {0.000001f, 1024 * 1024, 1, 3, 1234ULL}, + {0.000001f, 1024 * 1024, 4, 3, 1234ULL}, +}; + +typedef DotTest DotTestF; +TEST_P(DotTestF, Result) +{ + auto compare = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(compare(ref_output, host_output)); + ASSERT_TRUE(compare(ref_output, device_output)); +} + +typedef DotTest DotTestD; +TEST_P(DotTestD, Result) +{ + auto compare = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(compare(ref_output, host_output)); + ASSERT_TRUE(compare(ref_output, device_output)); +} + +INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf)); + +INSTANTIATE_TEST_SUITE_P(DotTests, DotTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace linalg +} // end namespace raft diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index e73f9b8a7a..38ed031759 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -77,6 +77,18 @@ int n_cols = 10; auto matrix = raft::make_managed_mdspan(managed_ptr, raft::make_matrix_extents(n_rows, n_cols)); ``` +You can also create strided mdspans: + +```c++ + +#include + +int n_elements = 10; +int stride = 10; + +auto vector = raft::make_device_vector_view(vector_ptr, raft::make_vector_strided_layout(n_elements, stride)); +``` + ## C++ Example