Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unittest for linalg::mean_squared_error #961

Merged
merged 1 commit into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/include/raft/linalg/mean_squared_error.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ namespace linalg {
* @param weight weight to apply to every term in the mean squared error calculation
* @param stream cuda-stream where to launch this kernel
*/
template <typename in_t, typename out_t, typename idx_t>
template <typename in_t, typename out_t, typename idx_t = size_t>
void meanSquaredError(
out_t* out, const in_t* A, const in_t* B, size_t len, in_t weight, cudaStream_t stream)
out_t* out, const in_t* A, const in_t* B, idx_t len, in_t weight, cudaStream_t stream)
{
detail::meanSquaredError(out, A, B, len, weight, stream);
}
Expand Down
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ if(BUILD_TESTS)
test/linalg/map.cu
test/linalg/map_then_reduce.cu
test/linalg/matrix_vector_op.cu
test/linalg/mean_squared_error.cu
test/linalg/multiply.cu
test/linalg/norm.cu
test/linalg/power.cu
Expand Down
131 changes: 131 additions & 0 deletions cpp/test/linalg/mean_squared_error.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <raft/linalg/mean_squared_error.cuh>

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>

namespace raft {
namespace linalg {

// reference MSE calculation
template <typename T>
__global__ void naiveMeanSquaredError(const int n, const T* a, const T* b, T weight, T* out)
{
T err = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
T diff = a[i] - b[i];
err += weight * diff * diff / n;
}
atomicAdd(out, err);
}

template <typename T, typename IndexType = std::uint32_t>
struct MeanSquaredErrorInputs {
T tolerance;
IndexType len;
T weight;
unsigned long long int seed;
};

template <typename T>
class MeanSquaredErrorTest : public ::testing::TestWithParam<MeanSquaredErrorInputs<T>> {
protected:
MeanSquaredErrorInputs<T> params;

raft::handle_t handle;
rmm::device_scalar<T> output;
rmm::device_scalar<T> refoutput;

public:
MeanSquaredErrorTest()
: testing::TestWithParam<MeanSquaredErrorInputs<T>>(),
output(0, handle.get_stream()),
refoutput(0, handle.get_stream())
{
handle.sync_stream();
}

protected:
void SetUp() override
{
params = ::testing::TestWithParam<MeanSquaredErrorInputs<T>>::GetParam();

cudaStream_t stream = handle.get_stream();

raft::random::RngState r(params.seed);

rmm::device_uvector<T> a(params.len, stream);
rmm::device_uvector<T> b(params.len, stream);
uniform(handle, r, a.data(), params.len, T(-1.0), T(1.0));
uniform(handle, r, b.data(), params.len, T(-1.0), T(1.0));
handle.sync_stream();

mean_squared_error<T, std::uint32_t, T>(handle,
make_device_vector_view<const T>(a.data(), params.len),
make_device_vector_view<const T>(b.data(), params.len),
make_device_scalar_view<T>(output.data()),
params.weight);

naiveMeanSquaredError<<<256, 256, 0, stream>>>(
params.len, a.data(), b.data(), params.weight, refoutput.data());
handle.sync_stream();
}

void TearDown() override {}
};

const std::vector<MeanSquaredErrorInputs<float>> inputsf = {
{0.0001f, 1024 * 1024, 1.0, 1234ULL},
{0.0001f, 4 * 1024 * 1024, 8.0, 1234ULL},
{0.0001f, 16 * 1024 * 1024, 24.0, 1234ULL},
{0.0001f, 98689, 1.0, 1234ULL},
};

const std::vector<MeanSquaredErrorInputs<double>> inputsd = {
{0.0001f, 1024 * 1024, 1.0, 1234ULL},
{0.0001f, 4 * 1024 * 1024, 8.0, 1234ULL},
{0.0001f, 16 * 1024 * 1024, 24.0, 1234ULL},
{0.0001f, 98689, 1.0, 1234ULL},
};

typedef MeanSquaredErrorTest<float> MeanSquaredErrorTestF;
TEST_P(MeanSquaredErrorTestF, Result)
{
ASSERT_TRUE(raft::devArrMatch(
refoutput.data(), output.data(), 1, raft::CompareApprox<float>(params.tolerance)));
}

typedef MeanSquaredErrorTest<double> MeanSquaredErrorTestD;
TEST_P(MeanSquaredErrorTestD, Result)
{
ASSERT_TRUE(raft::devArrMatch(
refoutput.data(), output.data(), 1, raft::CompareApprox<double>(params.tolerance)));
}

INSTANTIATE_TEST_SUITE_P(MeanSquaredErrorTests,
MeanSquaredErrorTestF,
::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_SUITE_P(MeanSquaredErrorTests,
MeanSquaredErrorTestD,
::testing::ValuesIn(inputsd));

} // end namespace linalg
} // end namespace raft