Commit

Expose linalg::dot in public API
benfred committed Oct 31, 2022
1 parent 0df8493 commit 01dd067
Showing 5 changed files with 244 additions and 12 deletions.
22 changes: 21 additions & 1 deletion cpp/include/raft/core/device_mdspan.hpp
@@ -276,4 +276,24 @@ auto make_device_vector_view(ElementType* ptr, IndexType n)
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, n};
}

} // end namespace raft
/**
 * @brief Create a 1-dim mdspan instance for a device pointer, using a strided layout
 * @tparam ElementType the data type of the vector elements
 * @tparam IndexType the index type of the extents
 * @tparam LayoutPolicy policy for strides and layout ordering
 * @param[in] ptr pointer to device memory to wrap
 * @param[in] n number of elements in the buffer
 * @param[in] stride stride between consecutive elements
 * @return raft::device_vector_view
 */
template <typename ElementType,
typename IndexType = int,
typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
{
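  // Build a strided layout mapping from the extent and stride, then wrap the pointer in a view.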
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::template mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
}
} // end namespace raft
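
For context, a minimal call-site sketch of the new helper (not part of the commit; the buffer name and sizes are illustrative, and a raft::handle_t is assumed to be available):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <rmm/device_uvector.hpp>

void strided_view_example(const raft::handle_t& handle)
{
  // Wrap 12 floats as a 4-element vector with stride 3,
  // i.e. the elements at offsets 0, 3, 6 and 9.
  rmm::device_uvector<float> buf(12, handle.get_stream());
  auto view = raft::make_strided_device_vector_view<const float>(buf.data(), 4, 3);
  // view.extent(0) == 4, view.stride(0) == 3
}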
70 changes: 70 additions & 0 deletions cpp/include/raft/linalg/dot.cuh
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef __DOT_H
#define __DOT_H

#pragma once

#include <raft/linalg/detail/cublas_wrappers.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {

/**
 * @brief Computes the dot product of two vectors.
 * @tparam InputType1 raft::device_mdspan for the first input vector
 * @tparam InputType2 raft::device_mdspan for the second input vector
 * @tparam OutputType either a raft::host_scalar_view or raft::device_scalar_view for the output
 * @param[in] handle raft::handle_t
 * @param[in] x the first input vector
 * @param[in] y the second input vector
 * @param[out] out the dot product of the x and y vectors
 */
template <typename InputType1,
typename InputType2,
typename OutputType,
typename = raft::enable_if_input_device_mdspan<InputType1>,
typename = raft::enable_if_input_device_mdspan<InputType2>,
typename = raft::enable_if_output_mdspan<OutputType>>
void dot(const raft::handle_t& handle, InputType1 x, InputType2 y, OutputType out)
{
RAFT_EXPECTS(x.size() == y.size(),
"Size mismatch between x and y input vectors in raft::linalg::dot");

  // Right now the inputs and outputs all need to have the same value_type (float/double etc).
  // Try to output a meaningful compiler error if mismatched types are passed here.
  // Note: in the future we could remove this restriction by using the cublasDotEx function
  // in the cublas wrapper call, instead of the cublasSdot and cublasDdot functions.
static_assert(std::is_same_v<typename InputType1::value_type, typename InputType2::value_type>,
"Both input vectors need to have the same value_type in raft::linalg::dot call");
static_assert(
std::is_same_v<typename InputType1::value_type, typename OutputType::value_type>,
"Input vectors and output scalar need to have the same value_type in raft::linalg::dot call");

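  // Delegate to the typed cuBLAS dot wrapper (cublasSdot/cublasDdot); the mdspan
  // strides map directly to cuBLAS's incx/incy arguments.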
RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
x.size(),
x.data_handle(),
x.stride(0),
y.data_handle(),
y.stride(0),
out.data_handle(),
handle.get_stream()));
}
} // namespace raft::linalg
#endif
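
As a usage note (not part of the commit), calling the new API looks like the following; this mirrors the non-strided path exercised by the test added below, with illustrative names:

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/dot.cuh>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

float dot_example(const raft::handle_t& handle,
                  const rmm::device_uvector<float>& x,
                  const rmm::device_uvector<float>& y)
{
  // dot() writes its result through the scalar view.
  rmm::device_scalar<float> result(0, handle.get_stream());
  auto n = static_cast<int>(x.size());
  raft::linalg::dot(handle,
                    raft::make_device_vector_view<const float>(x.data(), n),
                    raft::make_device_vector_view<const float>(y.data(), n),
                    raft::make_device_scalar_view<float>(result.data()));
  // value() copies the result back to the host, synchronizing the stream.
  return result.value(handle.get_stream());
}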
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -134,6 +134,7 @@ if(BUILD_TESTS)
test/linalg/cholesky_r1.cu
test/linalg/coalesced_reduction.cu
test/linalg/divide.cu
test/linalg/dot.cu
test/linalg/eig.cu
test/linalg/eig_sel.cu
test/linalg/gemm_layout.cu
11 changes: 0 additions & 11 deletions cpp/test/linalg/axpy.cu
@@ -31,17 +31,6 @@ __global__ void naiveAxpy(const int n, const T alpha, const T* x, T* y, int incx
if (idx < n) { y[idx * incy] += alpha * x[idx * incx]; }
}

template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
{
vector_extent<IndexType> exts{n};
std::array<IndexType, 1> strides{stride};
auto layout = typename LayoutPolicy::mapping<vector_extent<IndexType>>{exts, strides};
return device_vector_view<ElementType, IndexType, LayoutPolicy>{ptr, layout};
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct AxpyInputs {
OutType tolerance;
152 changes: 152 additions & 0 deletions cpp/test/linalg/dot.cu
@@ -0,0 +1,152 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <raft/linalg/dot.cuh>

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>

namespace raft {
namespace linalg {

// Reference dot implementation: each thread accumulates a grid-stride partial sum,
// then atomically adds it into *out (which must be zero-initialized).
template <typename T>
__global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out)
{
T sum = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
sum += x[i * incx] * y[i * incy];
}
atomicAdd(out, sum);
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct DotInputs {
  OutType tolerance;            // comparison tolerance for the result
  IndexType len;                // number of (logical) elements in each vector
  IndexType incx;               // stride between consecutive elements of x
  IndexType incy;               // stride between consecutive elements of y
  unsigned long long int seed;  // RNG seed used to generate the input data
};

template <typename T>
class DotTest : public ::testing::TestWithParam<DotInputs<T>> {
protected:
raft::handle_t handle;
DotInputs<T> params;
rmm::device_scalar<T> output;
rmm::device_scalar<T> refoutput;

public:
DotTest()
: testing::TestWithParam<DotInputs<T>>(),
output(0, handle.get_stream()),
refoutput(0, handle.get_stream())
{
handle.sync_stream();
}

protected:
void SetUp() override
{
params = ::testing::TestWithParam<DotInputs<T>>::GetParam();

cudaStream_t stream = handle.get_stream();

raft::random::RngState r(params.seed);

int x_len = params.len * params.incx;
int y_len = params.len * params.incy;

rmm::device_uvector<T> x(x_len, stream);
rmm::device_uvector<T> y(y_len, stream);
uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0));
uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0));

naiveDot<<<256, 256, 0, stream>>>(
params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data());

auto out_view = make_device_scalar_view<T, int>(output.data());

if ((params.incx > 1) && (params.incy > 1)) {
dot(handle,
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
out_view);
} else if (params.incx > 1) {
dot(handle,
make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
make_device_vector_view<const T>(y.data(), params.len),
out_view);
} else if (params.incy > 1) {
dot(handle,
make_device_vector_view<const T>(x.data(), params.len),
make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
out_view);
} else {
dot(handle,
make_device_vector_view<const T>(x.data(), params.len),
make_device_vector_view<const T>(y.data(), params.len),
out_view);
}
handle.sync_stream();
}

void TearDown() override {}
};

const std::vector<DotInputs<float>> inputsf = {
{0.0001f, 1024 * 1024, 1, 1, 1234ULL},
{0.0001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
{0.0001f, 98689, 1, 1, 1234ULL},
{0.0001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
{0.0001f, 1024 * 1024, 4, 1, 1234ULL},
{0.0001f, 1024 * 1024, 1, 3, 1234ULL},
{0.0001f, 1024 * 1024, 4, 3, 1234ULL},
};

const std::vector<DotInputs<double>> inputsd = {
{0.000001f, 1024 * 1024, 1, 1, 1234ULL},
{0.000001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
{0.000001f, 98689, 1, 1, 1234ULL},
{0.000001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
{0.000001f, 1024 * 1024, 4, 1, 1234ULL},
{0.000001f, 1024 * 1024, 1, 3, 1234ULL},
{0.000001f, 1024 * 1024, 4, 3, 1234ULL},
};

typedef DotTest<float> DotTestF;
TEST_P(DotTestF, Result)
{
ASSERT_TRUE(raft::devArrMatch(
refoutput.data(), output.data(), 1, raft::CompareApprox<float>(params.tolerance)));
}

typedef DotTest<double> DotTestD;
TEST_P(DotTestD, Result)
{
ASSERT_TRUE(raft::devArrMatch(
refoutput.data(), output.data(), 1, raft::CompareApprox<double>(params.tolerance)));
}

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestD, ::testing::ValuesIn(inputsd));

} // end namespace linalg
} // end namespace raft
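
For reference, the quantity that both the naiveDot kernel and the cuBLAS-backed raft::linalg::dot compute in the strided cases is simply sum_i x[i * incx] * y[i * incy]. A host-side analogue of the reference kernel (a sketch, not part of the commit):

// Host-side analogue of the naiveDot reference kernel (illustrative only).
template <typename T>
T naive_dot_host(int n, const T* x, int incx, const T* y, int incy)
{
  T sum = 0;
  for (int i = 0; i < n; ++i) {
    sum += x[i * incx] * y[i * incy];
  }
  return sum;
}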
