diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index 298f32339c..a3360ae35a 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -35,9 +35,9 @@ namespace linalg { * @param weight weight to apply to every term in the mean squared error calculation * @param stream cuda-stream where to launch this kernel */ -template +template void meanSquaredError( - out_t* out, const in_t* A, const in_t* B, size_t len, in_t weight, cudaStream_t stream) + out_t* out, const in_t* A, const in_t* B, idx_t len, in_t weight, cudaStream_t stream) { detail::meanSquaredError(out, A, B, len, weight, stream); } diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0d5af9be5c..8f04537b6b 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -140,6 +140,7 @@ if(BUILD_TESTS) test/linalg/map.cu test/linalg/map_then_reduce.cu test/linalg/matrix_vector_op.cu + test/linalg/mean_squared_error.cu test/linalg/multiply.cu test/linalg/norm.cu test/linalg/power.cu diff --git a/cpp/test/linalg/mean_squared_error.cu b/cpp/test/linalg/mean_squared_error.cu new file mode 100644 index 0000000000..795f831417 --- /dev/null +++ b/cpp/test/linalg/mean_squared_error.cu @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "../test_utils.h" +#include +#include +#include +#include + +namespace raft { +namespace linalg { + +// reference MSE calculation +template +__global__ void naiveMeanSquaredError(const int n, const T* a, const T* b, T weight, T* out) +{ + T err = 0; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) { + T diff = a[i] - b[i]; + err += weight * diff * diff / n; + } + atomicAdd(out, err); +} + +template +struct MeanSquaredErrorInputs { + T tolerance; + IndexType len; + T weight; + unsigned long long int seed; +}; + +template +class MeanSquaredErrorTest : public ::testing::TestWithParam> { + protected: + MeanSquaredErrorInputs params; + + raft::handle_t handle; + rmm::device_scalar output; + rmm::device_scalar refoutput; + + public: + MeanSquaredErrorTest() + : testing::TestWithParam>(), + output(0, handle.get_stream()), + refoutput(0, handle.get_stream()) + { + handle.sync_stream(); + } + + protected: + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + + cudaStream_t stream = handle.get_stream(); + + raft::random::RngState r(params.seed); + + rmm::device_uvector a(params.len, stream); + rmm::device_uvector b(params.len, stream); + uniform(handle, r, a.data(), params.len, T(-1.0), T(1.0)); + uniform(handle, r, b.data(), params.len, T(-1.0), T(1.0)); + handle.sync_stream(); + + mean_squared_error(handle, + make_device_vector_view(a.data(), params.len), + make_device_vector_view(b.data(), params.len), + make_device_scalar_view(output.data()), + params.weight); + + naiveMeanSquaredError<<<256, 256, 0, stream>>>( + params.len, a.data(), b.data(), params.weight, refoutput.data()); + handle.sync_stream(); + } + + void TearDown() override {} +}; + +const std::vector> inputsf = { + {0.0001f, 1024 * 1024, 1.0, 1234ULL}, + {0.0001f, 4 * 1024 * 1024, 8.0, 1234ULL}, + {0.0001f, 16 * 1024 * 1024, 24.0, 1234ULL}, + {0.0001f, 98689, 1.0, 1234ULL}, +}; + +const std::vector> inputsd = { + {0.0001f, 1024 * 1024, 1.0, 1234ULL}, + {0.0001f, 4 * 1024 * 1024, 8.0, 1234ULL}, + {0.0001f, 16 * 1024 * 1024, 24.0, 1234ULL}, + {0.0001f, 98689, 1.0, 1234ULL}, +}; + +typedef MeanSquaredErrorTest MeanSquaredErrorTestF; +TEST_P(MeanSquaredErrorTestF, Result) +{ + ASSERT_TRUE(raft::devArrMatch( + refoutput.data(), output.data(), 1, raft::CompareApprox(params.tolerance))); +} + +typedef MeanSquaredErrorTest MeanSquaredErrorTestD; +TEST_P(MeanSquaredErrorTestD, Result) +{ + ASSERT_TRUE(raft::devArrMatch( + refoutput.data(), output.data(), 1, raft::CompareApprox(params.tolerance))); +} + +INSTANTIATE_TEST_SUITE_P(MeanSquaredErrorTests, + MeanSquaredErrorTestF, + ::testing::ValuesIn(inputsf)); + +INSTANTIATE_TEST_SUITE_P(MeanSquaredErrorTests, + MeanSquaredErrorTestD, + ::testing::ValuesIn(inputsd)); + +} // end namespace linalg +} // end namespace raft