Closes rapidsai#805
Showing 5 changed files with 244 additions and 12 deletions.
@@ -0,0 +1,70 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __DOT_H
#define __DOT_H

#pragma once

#include <raft/linalg/detail/cublas_wrappers.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {

/**
 * @brief Computes the dot product of two vectors.
 * @tparam InputType1 raft::device_mdspan for the first input vector
 * @tparam InputType2 raft::device_mdspan for the second input vector
 * @tparam OutputType Either a host_scalar_view or device_scalar_view for the output
 * @param[in] handle raft::handle_t
 * @param[in] x First input vector
 * @param[in] y Second input vector
 * @param[out] out The output dot product between the x and y vectors
 */
template <typename InputType1,
          typename InputType2,
          typename OutputType,
          typename = raft::enable_if_input_device_mdspan<InputType1>,
          typename = raft::enable_if_input_device_mdspan<InputType2>,
          typename = raft::enable_if_output_mdspan<OutputType>>
void dot(const raft::handle_t& handle, InputType1 x, InputType2 y, OutputType out)
{
  RAFT_EXPECTS(x.size() == y.size(),
               "Size mismatch between x and y input vectors in raft::linalg::dot");

  // Right now the inputs and outputs all need to have the same value_type (float/double etc.).
  // Try to output a meaningful compiler error if mismatched types are passed here.
  // Note: In the future we could remove this restriction by using the cublasDotEx function
  // in the cublas wrapper call, instead of the cublasSdot and cublasDdot functions.
  static_assert(std::is_same_v<typename InputType1::value_type, typename InputType2::value_type>,
                "Both input vectors need to have the same value_type in raft::linalg::dot call");
  static_assert(
    std::is_same_v<typename InputType1::value_type, typename OutputType::value_type>,
    "Input vectors and output scalar need to have the same value_type in raft::linalg::dot call");

  RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
                                    x.size(),
                                    x.data_handle(),
                                    x.stride(0),
                                    y.data_handle(),
                                    y.stride(0),
                                    out.data_handle(),
                                    handle.get_stream()));
}
}  // namespace raft::linalg
#endif
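
For context (not part of the diff itself), here is a minimal usage sketch of the new API, mirroring the calls exercised in the test below; the vector length n and the assumption that x and y are already filled on the device are illustrative:

  raft::handle_t handle;
  auto stream = handle.get_stream();

  const int n = 1024;                       // illustrative length
  rmm::device_uvector<float> x(n, stream);  // assumed to be populated on the device
  rmm::device_uvector<float> y(n, stream);
  rmm::device_scalar<float> result(0, stream);

  // Contiguous device vectors wrapped as mdspan views; the result is written to a device scalar.
  raft::linalg::dot(handle,
                    raft::make_device_vector_view<const float>(x.data(), n),
                    raft::make_device_vector_view<const float>(y.data(), n),
                    raft::make_device_scalar_view<float, int>(result.data()));
  handle.sync_stream();

Per the docstring, a host_scalar_view output (e.g. raft::make_host_scalar_view<float>(&host_result)) should also be accepted, although the tests in this commit only exercise the device-scalar output path.
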
@@ -0,0 +1,152 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <raft/linalg/dot.cuh>

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

namespace raft {
namespace linalg {

// Reference dot implementation.
template <typename T>
__global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out)
{
  T sum = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    sum += x[i * incx] * y[i * incy];
  }
  atomicAdd(out, sum);
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct DotInputs {
  OutType tolerance;
  IndexType len;
  IndexType incx;
  IndexType incy;
  unsigned long long int seed;
};

template <typename T>
class DotTest : public ::testing::TestWithParam<DotInputs<T>> {
 protected:
  raft::handle_t handle;
  DotInputs<T> params;
  rmm::device_scalar<T> output;
  rmm::device_scalar<T> refoutput;

 public:
  DotTest()
    : testing::TestWithParam<DotInputs<T>>(),
      output(0, handle.get_stream()),
      refoutput(0, handle.get_stream())
  {
    handle.sync_stream();
  }

 protected:
  void SetUp() override
  {
    params = ::testing::TestWithParam<DotInputs<T>>::GetParam();

    cudaStream_t stream = handle.get_stream();

    raft::random::RngState r(params.seed);

    int x_len = params.len * params.incx;
    int y_len = params.len * params.incy;

    rmm::device_uvector<T> x(x_len, stream);
    rmm::device_uvector<T> y(y_len, stream);
    uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0));
    uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0));

    naiveDot<<<256, 256, 0, stream>>>(
      params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data());

    auto out_view = make_device_scalar_view<T, int>(output.data());

    if ((params.incx > 1) && (params.incy > 1)) {
      dot(handle,
          make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
          make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
          out_view);
    } else if (params.incx > 1) {
      dot(handle,
          make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
          make_device_vector_view<const T>(y.data(), params.len),
          out_view);
    } else if (params.incy > 1) {
      dot(handle,
          make_device_vector_view<const T>(x.data(), params.len),
          make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
          out_view);
    } else {
      dot(handle,
          make_device_vector_view<const T>(x.data(), params.len),
          make_device_vector_view<const T>(y.data(), params.len),
          out_view);
    }
    handle.sync_stream();
  }

  void TearDown() override {}
};

const std::vector<DotInputs<float>> inputsf = {
  {0.0001f, 1024 * 1024, 1, 1, 1234ULL},
  {0.0001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
  {0.0001f, 98689, 1, 1, 1234ULL},
  {0.0001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
  {0.0001f, 1024 * 1024, 4, 1, 1234ULL},
  {0.0001f, 1024 * 1024, 1, 3, 1234ULL},
  {0.0001f, 1024 * 1024, 4, 3, 1234ULL},
};

const std::vector<DotInputs<double>> inputsd = {
  {0.000001f, 1024 * 1024, 1, 1, 1234ULL},
  {0.000001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
  {0.000001f, 98689, 1, 1, 1234ULL},
  {0.000001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
  {0.000001f, 1024 * 1024, 4, 1, 1234ULL},
  {0.000001f, 1024 * 1024, 1, 3, 1234ULL},
  {0.000001f, 1024 * 1024, 4, 3, 1234ULL},
};

typedef DotTest<float> DotTestF;
TEST_P(DotTestF, Result)
{
  ASSERT_TRUE(raft::devArrMatch(
    refoutput.data(), output.data(), 1, raft::CompareApprox<float>(params.tolerance)));
}

typedef DotTest<double> DotTestD;
TEST_P(DotTestD, Result)
{
  ASSERT_TRUE(raft::devArrMatch(
    refoutput.data(), output.data(), 1, raft::CompareApprox<double>(params.tolerance)));
}

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestD, ::testing::ValuesIn(inputsd));

}  // end namespace linalg
}  // end namespace raft
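
Once the RAFT C++ tests are built, the new cases can be run in isolation with a standard gtest filter, e.g. ./test_raft --gtest_filter='DotTests/*'; the binary name test_raft is an assumption based on a typical RAFT test build and is not defined by this commit.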