Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose linalg::dot in public API #968

Merged
merged 14 commits into from
Nov 10, 2022
20 changes: 19 additions & 1 deletion cpp/include/raft/core/device_mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,4 +276,22 @@ auto make_device_vector_view(ElementType* ptr, IndexType n)
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}

} // end namespace raft
/**
* @brief Create a 1-dim mdspan instance for device pointer, using a strided layout
* @tparam ElementType the data type of the vector elements
* @tparam IndexType the index type of the extents
* @tparam LayoutPolicy policy for strides and layout ordering
* @param[in] ptr on device to wrap
* @param[in] n number of elements in pointer
* @param[in] stride stride between elements
* @return raft::device_vector_view
*/
template <typename ElementType, typename IndexType = int, typename LayoutPolicy = layout_stride>
benfred marked this conversation as resolved.
Show resolved Hide resolved
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding another factory function for a strided vector, why not just allow a strided layout to be configured in the make_device_vector_view and make_host_vector_view?

Right now the make_*_vector_view automatically configures a row-major layout but the layout should really be configurable (and potentially strided, or col major if desired).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated make_device_vector_view to allow strided input here - let me know what you think.

{
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::template mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
}
} // end namespace raft
70 changes: 70 additions & 0 deletions cpp/include/raft/linalg/dot.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef __DOT_H
#define __DOT_H

#pragma once

#include <raft/linalg/detail/cublas_wrappers.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {

/**
* @brief Computes the dot product of two vectors.
* @tparam InputType1 raft::device_mdspan for the first input vector
* @tparam InputType2 raft::device_mdspan for the second input vector
* @tparam OutputType Either a host_scalar_view or device_scalar_view for the output
* @param[in] handle raft::handle_t
* @param[in] x First input vector
* @param[in] y Second input vector
* @param[out] out The output dot product between the x and y vectors
*/
template <typename InputType1,
typename InputType2,
typename OutputType,
typename = raft::enable_if_input_device_mdspan<InputType1>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I brought this up with the axpy as well, but it seems weird to accept a general mdspan for this when what we are really looking for is a 1d vector. Do you see value in accepting a matrix or dense tensor with 3+ dimensional extents? If not, we should just accept the vector_view directly (which is aliased to be any mdspan with 1d extents.

If we accepted a device_vector_view directly, we wouldn't need the enable_if statements at all. I think we should go ahead and do the same for the axpy to keep things consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed - made the changes here so that both axpy and dot take device_vector_view's

typename = raft::enable_if_input_device_mdspan<InputType2>,
typename = raft::enable_if_output_mdspan<OutputType>>
void dot(const raft::handle_t& handle, InputType1 x, InputType2 y, OutputType out)
{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

// Right now the inputs and outputs need to all have the same value_type (float/double etc).
// Try to output a meaningful compiler error if mismatched types are passed here.
// Note: In the future we could remove this restriction using the cublasDotEx function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just go ahead and wrap the cublasEx functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created an issue so we can discuss further #977 .

Reading the docs a little closer, and it looks like even w/ cublasDotEx having different dtypes for the input/outputs isn't currently supported: https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx - so it won't have much value for the dot API (though I could see a use for it myself with the gemm api w/ implicit and the mixed precision work I was talking about last week)

// in the cublas wrapper call, instead of the cublassdot and cublasddot functions.
static_assert(std::is_same_v<typename InputType1::value_type, typename InputType2::value_type>,
"Both input vectors need to have the same value_type in raft::linalg::dot call");
static_assert(
std::is_same_v<typename InputType1::value_type, typename OutputType::value_type>,
"Input vectors and output scalar need to have the same value_type in raft::linalg::dot call");

RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out.data_handle(),
handle.get_stream()));
}
} // namespace raft::linalg
#endif
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ if(BUILD_TESTS)
test/linalg/cholesky_r1.cu
test/linalg/coalesced_reduction.cu
test/linalg/divide.cu
test/linalg/dot.cu
test/linalg/eig.cu
test/linalg/eig_sel.cu
test/linalg/gemm_layout.cu
Expand Down
11 changes: 0 additions & 11 deletions cpp/test/linalg/axpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,6 @@ __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx
if (idx < n) { y[idx * incy] += alpha * x[idx * incx]; }
}

template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
{
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct AxpyInputs {
OutType tolerance;
Expand Down
152 changes: 152 additions & 0 deletions cpp/test/linalg/dot.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <raft/linalg/dot.cuh>

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>

namespace raft {
namespace linalg {

// Reference dot implementation.
template <typename T>
__global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out)
{
T sum = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
sum += x[i * incx] * y[i * incy];
}
atomicAdd(out, sum);
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct DotInputs {
OutType tolerance;
IndexType len;
IndexType incx;
IndexType incy;
unsigned long long int seed;
};

template <typename T>
class DotTest : public ::testing::TestWithParam<DotInputs<T>> {
protected:
raft::handle_t handle;
DotInputs<T> params;
rmm::device_scalar<T> output;
rmm::device_scalar<T> refoutput;

public:
DotTest()
: testing::TestWithParam<DotInputs<T>>(),
output(0, handle.get_stream()),
refoutput(0, handle.get_stream())
{
handle.sync_stream();
}

protected:
void SetUp() override
{
params = ::testing::TestWithParam<DotInputs<T>>::GetParam();

cudaStream_t stream = handle.get_stream();

raft::random::RngState r(params.seed);

int x_len = params.len * params.incx;
int y_len = params.len * params.incy;

rmm::device_uvector<T> x(x_len, stream);
rmm::device_uvector<T> y(y_len, stream);
uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0));
uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0));

naiveDot<<<256, 256, 0, stream>>>(
params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data());

auto out_view = make_device_scalar_view<T, int>(output.data());

if ((params.incx > 1) && (params.incy > 1)) {
dot(handle,
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
out_view);
} else if (params.incx > 1) {
dot(handle,
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_device_vector_view<const T>(y.data(), params.len),
out_view);
} else if (params.incy > 1) {
dot(handle,
make_device_vector_view<const T>(x.data(), params.len),
make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
out_view);
} else {
dot(handle,
make_device_vector_view<const T>(x.data(), params.len),
make_device_vector_view<const T>(y.data(), params.len),
out_view);
}
handle.sync_stream();
}

void TearDown() override {}
};

const std::vector<DotInputs<float>> inputsf = {
{0.0001f, 1024 * 1024, 1, 1, 1234ULL},
{0.0001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
{0.0001f, 98689, 1, 1, 1234ULL},
{0.0001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
{0.0001f, 1024 * 1024, 4, 1, 1234ULL},
{0.0001f, 1024 * 1024, 1, 3, 1234ULL},
{0.0001f, 1024 * 1024, 4, 3, 1234ULL},
};

const std::vector<DotInputs<double>> inputsd = {
{0.000001f, 1024 * 1024, 1, 1, 1234ULL},
{0.000001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
{0.000001f, 98689, 1, 1, 1234ULL},
{0.000001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
{0.000001f, 1024 * 1024, 4, 1, 1234ULL},
{0.000001f, 1024 * 1024, 1, 3, 1234ULL},
{0.000001f, 1024 * 1024, 4, 3, 1234ULL},
};

typedef DotTest<float> DotTestF;
TEST_P(DotTestF, Result)
{
ASSERT_TRUE(raft::devArrMatch(
refoutput.data(), output.data(), 1, raft::CompareApprox<float>(params.tolerance)));
}

typedef DotTest<double> DotTestD;
TEST_P(DotTestD, Result)
{
ASSERT_TRUE(raft::devArrMatch(
refoutput.data(), output.data(), 1, raft::CompareApprox<double>(params.tolerance)));
}

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestD, ::testing::ValuesIn(inputsd));

} // end namespace linalg
} // end namespace raft