Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose linalg::dot in public API #968

Merged
merged 14 commits into from
Nov 10, 2022
49 changes: 34 additions & 15 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,23 +266,42 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] n number of elements in pointer
* @param[in] stride the stride between consecutive elements in the vector. Setting to a value
* other than 1 requires LayoutPolicy to be set to layout_stride
* @return raft::device_vector_view
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1)
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n)
{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}

/**
* @brief Create a 1-dim mdspan instance for device pointer.
* @tparam ElementType the data type of the vector elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] mapping The layout mapping to use for this vector
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(
ElementType* ptr,
const typename LayoutPolicy::template mapping<vector_extent<IndexType>>& mapping)
{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, mapping};
}

/**
* @brief Construct a strided vector layout mapping
* @tparam IndexType the index type of the extents
* @params[in] n the number of elements in the vector
* @params[in] stride the stride between elements in the vector
*/
template <typename IndexType>
auto make_vector_strided_layout(IndexType n, IndexType stride)
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
{
if constexpr (std::is_same_v<layout_stride, LayoutPolicy>) {
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::template mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
} else {
RAFT_EXPECTS(stride == 1, "Having a stride != 1 requires a layout_stride LayoutPolicy");
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
return layout_stride::mapping<vector_extent<IndexType>>{exts, strides};
}
} // end namespace raft
16 changes: 8 additions & 8 deletions cpp/include/raft/linalg/axpy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ void axpy(const raft::handle_t& handle,
* @param [inout] y Output vector
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename InLayoutPolicy = layout_c_contiguous,
typename OutLayoutPolicy = layout_c_contiguous,
typename ScalarIdxType = std::uint32_t>
typename IndexType,
typename InLayoutPolicy,
typename OutLayoutPolicy,
typename ScalarIdxType>
void axpy(const raft::handle_t& handle,
raft::device_scalar_view<const ElementType, ScalarIdxType> alpha,
raft::device_vector_view<const ElementType, IndexType, InLayoutPolicy> x,
Expand All @@ -98,10 +98,10 @@ void axpy(const raft::handle_t& handle,
* @param [inout] y Output vector
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename InLayoutPolicy = layout_c_contiguous,
typename OutLayoutPolicy = layout_c_contiguous,
typename ScalarIdxType = std::uint32_t>
typename IndexType,
typename InLayoutPolicy,
typename OutLayoutPolicy,
typename ScalarIdxType>
void axpy(const raft::handle_t& handle,
raft::host_scalar_view<const ElementType, ScalarIdxType> alpha,
raft::device_vector_view<const ElementType, IndexType, InLayoutPolicy> x,
Expand Down
69 changes: 64 additions & 5 deletions cpp/include/raft/linalg/dot.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ namespace raft::linalg {
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
* @note The out parameter can be either a host_scalar_view or device_scalar_view
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename ScalarIndexType = std::uint32_t,
typename LayoutPolicy1 = layout_c_contiguous,
typename LayoutPolicy2 = layout_c_contiguous>
typename IndexType,
typename ScalarIndexType,
typename LayoutPolicy1,
typename LayoutPolicy2>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
Expand All @@ -55,5 +54,65 @@ void dot(const raft::handle_t& handle,
out.data_handle(),
handle.get_stream()));
}

/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
*/
template <typename ElementType,
typename IndexType,
typename ScalarIndexType,
typename LayoutPolicy1,
typename LayoutPolicy2>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
raft::host_scalar_view<ElementType, ScalarIndexType> out)
{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out.data_handle(),
handle.get_stream()));
}

/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
*/
template <typename ElementType,
typename IndexType,
typename ScalarIndexType,
typename LayoutPolicy1,
typename LayoutPolicy2>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
ElementType* out)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the host output, we probably should drop this overload. Sorry for being confusing here. I think it makes more sense to accept a host scalar by value for functions like axpy where the scalar is an input. For output on host, I think we should stick to the mdspan scalar wrappers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed in latest commit

{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out,
handle.get_stream()));
}
} // namespace raft::linalg
#endif
14 changes: 7 additions & 7 deletions cpp/test/linalg/axpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

namespace raft {
namespace linalg {

// Reference axpy implementation.
template <typename T>
__global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx, int incy)
Expand Down Expand Up @@ -90,20 +89,21 @@ class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_device_vector_view<const T, IndexType, layout_stride>(
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
x.data(), params.len, params.incx),
make_device_vector_view<T, IndexType, layout_stride>(y.data(), params.len, params.incy));

x.data(), make_vector_strided_layout<IndexType>(params.len, params.incx)),
make_device_vector_view<T, IndexType, layout_stride>(
y.data(), make_vector_strided_layout(params.len, params.incy)));
} else if (params.incx > 1) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_device_vector_view<const T, IndexType, layout_stride>(
x.data(), params.len, params.incx),
make_device_vector_view<T>(y.data(), params.len));
x.data(), make_vector_strided_layout(params.len, params.incx)),
make_device_vector_view<T, IndexType>(y.data(), params.len));
} else if (params.incy > 1) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_device_vector_view<const T>(x.data(), params.len),
make_device_vector_view<T, IndexType, layout_stride>(y.data(), params.len, params.incy));
make_device_vector_view<T, IndexType, layout_stride>(
y.data(), make_vector_strided_layout(params.len, params.incy)));
} else {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
Expand Down
8 changes: 4 additions & 4 deletions cpp/test/linalg/dot.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,21 @@ class DotTest : public ::testing::TestWithParam<DotInputs<T>> {
if ((params.incx > 1) && (params.incy > 1)) {
dot(handle,
make_device_vector_view<const T, IndexType, layout_stride>(
x.data(), params.len, params.incx),
x.data(), make_vector_strided_layout(params.len, params.incx)),
make_device_vector_view<const T, IndexType, layout_stride>(
y.data(), params.len, params.incy),
y.data(), make_vector_strided_layout(params.len, params.incy)),
out_view);
} else if (params.incx > 1) {
dot(handle,
make_device_vector_view<const T, IndexType, layout_stride>(
x.data(), params.len, params.incx),
x.data(), make_vector_strided_layout(params.len, params.incx)),
make_device_vector_view<const T>(y.data(), params.len),
out_view);
} else if (params.incy > 1) {
dot(handle,
make_device_vector_view<const T>(x.data(), params.len),
make_device_vector_view<const T, IndexType, layout_stride>(
y.data(), params.len, params.incy),
y.data(), make_vector_strided_layout(params.len, params.incy)),
out_view);
} else {
dot(handle,
Expand Down