Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose linalg::dot in public API #968

Merged
merged 14 commits into from
Nov 10, 2022
36 changes: 32 additions & 4 deletions cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,12 +268,40 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col
* @param[in] n number of elements in pointer
* @return raft::device_vector_view
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n)
{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}

} // end namespace raft
/**
* @brief Create a 1-dim mdspan instance for device pointer.
* @tparam ElementType the data type of the vector elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] mapping The layout mapping to use for this vector
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType, typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(
ElementType* ptr,
const typename LayoutPolicy::template mapping<vector_extent<IndexType>>& mapping)
{
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, mapping};
}

/**
* @brief Construct a strided vector layout mapping
* @tparam IndexType the index type of the extents
* @params[in] n the number of elements in the vector
* @params[in] stride the stride between elements in the vector
*/
template <typename IndexType>
auto make_vector_strided_layout(IndexType n, IndexType stride)
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
{
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
return layout_stride::mapping<vector_extent<IndexType>>{exts, strides};
}
} // end namespace raft
69 changes: 32 additions & 37 deletions cpp/include/raft/linalg/axpy.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,66 +62,61 @@ void axpy(const raft::handle_t& handle,
* @brief axpy function
* It computes the following equation: y = alpha * x + y
*
* @tparam InType Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param [in] handle raft::handle_t
* @param [in] alpha raft::device_scalar_view
* @param [in] x Input vector
* @param [inout] y Output vector
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_input_device_mdspan<InType>,
typename = raft::enable_if_output_device_mdspan<OutType>>
template <typename ElementType,
typename IndexType,
typename InLayoutPolicy,
typename OutLayoutPolicy,
typename ScalarIdxType>
void axpy(const raft::handle_t& handle,
raft::device_scalar_view<const typename InType::value_type, ScalarIdxType> alpha,
InType x,
OutType y)
raft::device_scalar_view<const ElementType, ScalarIdxType> alpha,
raft::device_vector_view<const ElementType, IndexType, InLayoutPolicy> x,
raft::device_vector_view<ElementType, IndexType, OutLayoutPolicy> y)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename InType::value_type, true>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
axpy<ElementType, true>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
}

/**
* @brief axpy function
* It computes the following equation: y = alpha * x + y
*
* @tparam MdspanType Type raft::device_mdspan
* @tparam ScalarIdxType Index Type of scalar
* @param [in] handle raft::handle_t
* @param [in] alpha raft::device_scalar_view
* @param [in] x Input vector
* @param [inout] y Output vector
*/
template <typename InType,
typename OutType,
typename ScalarIdxType,
typename = raft::enable_if_input_device_mdspan<InType>,
typename = raft::enable_if_output_device_mdspan<OutType>>
template <typename ElementType,
typename IndexType,
typename InLayoutPolicy,
typename OutLayoutPolicy,
typename ScalarIdxType>
void axpy(const raft::handle_t& handle,
raft::host_scalar_view<const typename InType::value_type, ScalarIdxType> alpha,
InType x,
OutType y)
raft::host_scalar_view<const ElementType, ScalarIdxType> alpha,
raft::device_vector_view<const ElementType, IndexType, InLayoutPolicy> x,
raft::device_vector_view<ElementType, IndexType, OutLayoutPolicy> y)
{
RAFT_EXPECTS(y.size() == x.size(), "Size mismatch between Output and Input");

axpy<typename InType::value_type, false>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
axpy<ElementType, false>(handle,
y.size(),
alpha.data_handle(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
handle.get_stream());
}

/** @} */ // end of group axpy
Expand Down
118 changes: 118 additions & 0 deletions cpp/include/raft/linalg/dot.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef __DOT_H
#define __DOT_H

#pragma once

#include <raft/linalg/detail/cublas_wrappers.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {
/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
*/
template <typename ElementType,
typename IndexType,
typename ScalarIndexType,
typename LayoutPolicy1,
typename LayoutPolicy2>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
raft::device_scalar_view<ElementType, ScalarIndexType> out)
benfred marked this conversation as resolved.
Show resolved Hide resolved
{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out.data_handle(),
handle.get_stream()));
}

/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
*/
template <typename ElementType,
typename IndexType,
typename ScalarIndexType,
typename LayoutPolicy1,
typename LayoutPolicy2>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
raft::host_scalar_view<ElementType, ScalarIndexType> out)
{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out.data_handle(),
handle.get_stream()));
}

/**
* @brief Computes the dot product of two vectors.
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors.
*/
template <typename ElementType,
typename IndexType,
typename ScalarIndexType,
typename LayoutPolicy1,
typename LayoutPolicy2>
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
ElementType* out)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the host output, we probably should drop this overload. Sorry for being confusing here. I think it makes more sense to accept a host scalar by value for functions like axpy where the scalar is an input. For output on host, I think we should stick to the mdspan scalar wrappers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed in latest commit

{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out,
handle.get_stream()));
}
} // namespace raft::linalg
#endif
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ if(BUILD_TESTS)
test/linalg/cholesky_r1.cu
test/linalg/coalesced_reduction.cu
test/linalg/divide.cu
test/linalg/dot.cu
test/linalg/eig.cu
test/linalg/eig_sel.cu
test/linalg/gemm_layout.cu
Expand Down
34 changes: 13 additions & 21 deletions cpp/test/linalg/axpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

namespace raft {
namespace linalg {

// Reference axpy implementation.
template <typename T>
__global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx, int incy)
Expand All @@ -31,17 +30,6 @@ __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx
if (idx < n) { y[idx * incy] += alpha * x[idx * incx]; }
}

template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
{
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct AxpyInputs {
OutType tolerance;
Expand All @@ -52,11 +40,11 @@ struct AxpyInputs {
unsigned long long int seed;
};

template <typename T>
template <typename T, typename IndexType = int>
class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {
protected:
raft::handle_t handle;
AxpyInputs<T> params;
AxpyInputs<T, IndexType> params;
rmm::device_uvector<T> refy;
rmm::device_uvector<T> y;

Expand All @@ -78,8 +66,8 @@ class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {

raft::random::RngState r(params.seed);

int x_len = params.len * params.incx;
int y_len = params.len * params.incy;
IndexType x_len = params.len * params.incx;
IndexType y_len = params.len * params.incy;
rmm::device_uvector<T> x(x_len, stream);
y.resize(y_len, stream);
refy.resize(y_len, stream);
Expand All @@ -100,18 +88,22 @@ class AxpyTest : public ::testing::TestWithParam<AxpyInputs<T>> {
if ((params.incx > 1) && (params.incy > 1)) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_strided_device_vector_view<T>(y.data(), params.len, params.incy));
make_device_vector_view<const T, IndexType, layout_stride>(
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
x.data(), make_vector_strided_layout<IndexType>(params.len, params.incx)),
make_device_vector_view<T, IndexType, layout_stride>(
y.data(), make_vector_strided_layout(params.len, params.incy)));
} else if (params.incx > 1) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_device_vector_view<T>(y.data(), params.len));
make_device_vector_view<const T, IndexType, layout_stride>(
x.data(), make_vector_strided_layout(params.len, params.incx)),
make_device_vector_view<T, IndexType>(y.data(), params.len));
} else if (params.incy > 1) {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
make_device_vector_view<const T>(x.data(), params.len),
make_strided_device_vector_view<T>(y.data(), params.len, params.incy));
make_device_vector_view<T, IndexType, layout_stride>(
y.data(), make_vector_strided_layout(params.len, params.incy)));
} else {
axpy(handle,
make_host_scalar_view<const T>(&params.alpha),
Expand Down
Loading