Expose linalg::dot in public API #968

Merged
merged 14 commits into from
Nov 10, 2022
17 changes: 13 additions & 4 deletions cpp/include/raft/core/device_mdspan.hpp
@@ -266,14 +266,23 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] n number of elements in pointer
* @param[in] stride the stride between consecutive elements in the vector. Setting to a value
* other than 1 requires LayoutPolicy to be set to layout_stride
* @return raft::device_vector_view
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n)
auto make_device_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1)
Member


This is a little awkward. We accept a layout policy as a template argument, but then we also accept a function argument for a stride which essentially overrides the layout from the template.

Would a user achieve the same goal by setting a strided layout on the template argument directly? Perhaps we could provide a factory function to make said strided layout and give the user something like a statically sized object (e.g. std::array) to set the strides for each dimension?

And of course, this is one of those things (the new strided factory function) that I think should have a usage example in the doxygen and perhaps even a subsection in the mdspan tutorial markdown of the docs.

Member Author


If I'm understanding you correctly, you're thinking we can just pass the layout mapping to the make_device_vector_view function directly, and add a new factory function for creating this layout mapping?

I took a stab at that in the last commit. Unfortunately, I couldn't get a single make_device_vector_view function to compile when it was passed both an IndexType with the number of elements and the Mapping with the strided layout (I was getting compile errors in various other raft functions that I hadn't updated). However, I could get it to work by adding an overload, which is what's in the last commit. Do you have any suggestions on how to clean this up =) ?

I'll add something to the tutorial / docs once we're happy with the API.
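
For readers following the thread, here is a minimal host-only sketch of the overload approach discussed above: one factory that builds a strided mapping from a std::array of strides, plus two overloads of a view factory (pointer and length, or pointer and mapping). All names in it (strided_mapping, vector_view, make_strided_mapping, make_vector_view) are hypothetical stand-ins invented for illustration; they are not the RAFT or mdspan types, and the code in the commit may look different.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for illustration only; these are NOT the RAFT/mdspan types.
template <typename IndexType>
struct strided_mapping {
  IndexType extent;  // number of logical elements
  IndexType stride;  // distance, in elements, between consecutive logical elements
  constexpr IndexType operator()(IndexType i) const { return i * stride; }
};

template <typename ElementType, typename IndexType>
struct vector_view {
  ElementType* ptr;
  strided_mapping<IndexType> map;
  constexpr IndexType size() const { return map.extent; }
  constexpr IndexType stride() const { return map.stride; }
  constexpr ElementType& operator[](IndexType i) const { return ptr[map(i)]; }
};

// Factory for the mapping: the extent plus one stride per dimension in a std::array.
template <typename IndexType>
constexpr strided_mapping<IndexType> make_strided_mapping(IndexType n,
                                                          std::array<IndexType, 1> strides)
{
  return strided_mapping<IndexType>{n, strides[0]};
}

// Overload 1: contiguous view built from a pointer and a length (stride fixed to 1).
template <typename ElementType, typename IndexType>
constexpr vector_view<ElementType, IndexType> make_vector_view(ElementType* ptr, IndexType n)
{
  return {ptr, strided_mapping<IndexType>{n, IndexType{1}}};
}

// Overload 2: view built from a pointer and an explicit mapping.
// It is more specialized than overload 1, so it wins when a mapping is passed.
template <typename ElementType, typename IndexType>
constexpr vector_view<ElementType, IndexType> make_vector_view(ElementType* ptr,
                                                               strided_mapping<IndexType> map)
{
  return {ptr, map};
}
```

With this shape the stride lives in the mapping object rather than in an extra function argument, which is the separation the review comment asks about.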

{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
if constexpr (std::is_same_v<layout_stride, LayoutPolicy>) {
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::template mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
} else {
RAFT_EXPECTS(stride == 1, "Having a stride != 1 requires a layout_stride LayoutPolicy");
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}
}

} // end namespace raft
} // end namespace raft
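
The compile-time dispatch in the function body above can be tried in isolation. The following is a host-only sketch under stated assumptions: contiguous_tag, strided_tag, vector_view and make_vector_view are hypothetical stand-ins (not the RAFT types), and a thrown std::invalid_argument stands in for the RAFT_EXPECTS check.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <type_traits>

// Hypothetical layout tags and view type; they only mimic the shape of the PR code.
struct contiguous_tag {};
struct strided_tag {};

template <typename ElementType, typename IndexType, typename Layout>
struct vector_view {
  ElementType* ptr;
  IndexType n;
  IndexType stride;
  ElementType& operator[](IndexType i) const { return ptr[i * stride]; }
};

template <typename ElementType,
          typename IndexType = std::uint32_t,
          typename Layout    = contiguous_tag>
auto make_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1)
{
  if constexpr (std::is_same_v<Layout, strided_tag>) {
    // Strided layout: the runtime stride is honoured.
    return vector_view<ElementType, IndexType, Layout>{ptr, n, stride};
  } else {
    // Any other layout: a stride other than 1 is a usage error (assumed check).
    if (stride != 1) { throw std::invalid_argument("stride != 1 requires strided_tag"); }
    return vector_view<ElementType, IndexType, Layout>{ptr, n, IndexType{1}};
  }
}
```

Note that the layout is chosen at compile time while the stride check for the contiguous case still happens at run time, which is the awkwardness raised in the review.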
69 changes: 32 additions & 37 deletions cpp/include/raft/linalg/axpy.cuh
@@ -62,66 +62,61 @@ void axpy(const raft::handle_t& handle,
* @brief axpy function
* It computes the following equation: y = alpha * x + y
*
* @tparam InType Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param [in] handle raft::handle_t
* @param [in] alpha raft::device_scalar_view
* @param [in] x Input vector
* @param [inout] y Output vector
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_input_device_mdspan<InType>,
typename = raft::enable_if_output_device_mdspan<OutType>>
template <typename ElementType,
typename IndexType = std::uint32_t,
typename InLayoutPolicy = layout_c_contiguous,
typename OutLayoutPolicy = layout_c_contiguous,
typename ScalarIdxType = std::uint32_t>
void axpy(const raft::handle_t& handle,
raft::device_scalar_view<const typename InType::value_type, ScalarIdxType> alpha,
InType x,
OutType y)
raft::device_scalar_view<const ElementType, ScalarIdxType> alpha,
raft::device_vector_view<const ElementType, IndexType, InLayoutPolicy> x,
raft::device_vector_view<ElementType, IndexType, OutLayoutPolicy> y)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename InType::value_type, true>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
axpy<ElementType, true>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
}

/**
* @brief axpy function
* It computes the following equation: y = alpha * x + y
*
* @tparam MdspanType Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param [in] handle raft::handle_t
* @param [in] alpha raft::device_scalar_view
* @param [in] x Input vector
* @param [inout] y Output vector
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_input_device_mdspan<InType>,
typename = raft::enable_if_output_device_mdspan<OutType>>
template <typename ElementType,
typename IndexType = std::uint32_t,
typename InLayoutPolicy = layout_c_contiguous,
typename OutLayoutPolicy = layout_c_contiguous,
typename ScalarIdxType = std::uint32_t>
void axpy(const raft::handle_t& handle,
raft::host_scalar_view<const typename InType::value_type, ScalarIdxType> alpha,
InType x,
OutType y)
raft::host_scalar_view<const ElementType, ScalarIdxType> alpha,
raft::device_vector_view<const ElementType, IndexType, InLayoutPolicy> x,
raft::device_vector_view<ElementType, IndexType, OutLayoutPolicy> y)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename InType::value_type, false>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
axpy<ElementType, false>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
}

/** @} */ // end of group axpy
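
As a reference for what the strided arguments mean, the equation y = alpha * x + y with BLAS-style increments can be written as a plain host loop. This is a sketch for checking results on the CPU, not part of the PR; reference_axpy is a hypothetical helper name, and incx and incy play the role of x.stride(0) and y.stride(0).

```cpp
#include <cassert>
#include <cstddef>

// Host reference: y[i * incy] += alpha * x[i * incx] for i in [0, n).
// n counts logical elements, not the length of the underlying buffers.
template <typename T>
void reference_axpy(std::size_t n, T alpha, const T* x, std::size_t incx, T* y, std::size_t incy)
{
  assert(x != nullptr && y != nullptr);
  for (std::size_t i = 0; i < n; ++i) {
    y[i * incy] += alpha * x[i * incx];
  }
}
```

A buffer read with increment incx must hold at least (n - 1) * incx + 1 elements; the test in this PR allocates len * incx.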
59 changes: 59 additions & 0 deletions cpp/include/raft/linalg/dot.cuh
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef __DOT_H
#define __DOT_H

#pragma once

#include <raft/linalg/detail/cublas_wrappers.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {
/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
* @note The out parameter can be either a host_scalar_view or device_scalar_view
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename ScalarIndexType = std::uint32_t,
typename LayoutPolicy1 = layout_c_contiguous,
typename LayoutPolicy2 = layout_c_contiguous>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
raft::device_scalar_view<ElementType, ScalarIndexType> out)
{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out.data_handle(),
handle.get_stream()));
}
} // namespace raft::linalg
#endif
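
For comparison, the quantity this header computes can be expressed as a small host loop, which is handy when validating device results. This is a sketch only: reference_dot is a hypothetical helper, not part of RAFT, and incx and incy correspond to x.stride(0) and y.stride(0) in the call above.

```cpp
#include <cassert>
#include <cstddef>

// Host reference: sum over i in [0, n) of x[i * incx] * y[i * incy].
template <typename T>
T reference_dot(std::size_t n, const T* x, std::size_t incx, const T* y, std::size_t incy)
{
  assert(x != nullptr && y != nullptr);
  T acc{};  // value-initialised accumulator, i.e. zero for arithmetic types
  for (std::size_t i = 0; i < n; ++i) {
    acc += x[i * incx] * y[i * incy];
  }
  return acc;
}
```

Because floating-point addition is not associative, a device result may differ from this sequential sum in the last bits, so a tolerance-based comparison is the safer choice in tests.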
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -138,6 +138,7 @@ if(BUILD_TESTS)
test/linalg/cholesky_r1.cu
test/linalg/coalesced_reduction.cu
test/linalg/divide.cu
test/linalg/dot.cu
test/linalg/eig.cu
test/linalg/eig_sel.cu
test/linalg/gemm_layout.cu
30 changes: 11 additions & 19 deletions cpp/test/linalg/axpy.cu
@@ -31,17 +31,6 @@ __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx
if (idx < n) { y[idx * incy] += alpha * x[idx * incx]; }
}

template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
{
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct AxpyInputs {
OutType tolerance;
@@ -52,11 +41,11 @@ struct AxpyInputs {
unsigned long long int seed;
};

template <typename T>
template <typename T, typename IndexType = int>
class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {
protected:
raft::handle_t handle;
AxpyInputs<T> params;
AxpyInputs<T, IndexType> params;
rmm::device_uvector<T> refy;
rmm::device_uvector<T> y;

@@ -78,8 +67,8 @@ class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {

raft::random::RngState r(params.seed);

int x_len = params.len * params.incx;
int y_len = params.len * params.incy;
IndexType x_len = params.len * params.incx;
IndexType y_len = params.len * params.incy;
rmm::device_uvector<T> x(x_len, stream);
y.resize(y_len, stream);
refy.resize(y_len, stream);
@@ -100,18 +89,21 @@ class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {
if ((params.incx > 1) && (params.incy > 1)) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_strided_device_vector_view<T>(y.data(), params.len, params.incy));
make_device_vector_view<const T, IndexType, layout_stride>(
x.data(), params.len, params.incx),
make_device_vector_view<T, IndexType, layout_stride>(y.data(), params.len, params.incy));

} else if (params.incx > 1) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_device_vector_view<const T, IndexType, layout_stride>(
x.data(), params.len, params.incx),
make_device_vector_view<T>(y.data(), params.len));
} else if (params.incy > 1) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_device_vector_view<const T>(x.data(), params.len),
make_strided_device_vector_view<T>(y.data(), params.len, params.incy));
make_device_vector_view<T, IndexType, layout_stride>(y.data(), params.len, params.incy));
} else {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),