diff --git a/cpp/include/raft/stats/weighted_mean.cuh b/cpp/include/raft/stats/weighted_mean.cuh index 0e8338fe84..65d1b2c35f 100644 --- a/cpp/include/raft/stats/weighted_mean.cuh +++ b/cpp/include/raft/stats/weighted_mean.cuh @@ -19,6 +19,7 @@ #pragma once +#include #include namespace raft { @@ -91,6 +92,90 @@ void colWeightedMean( { weightedMean(mu, data, weights, D, N, true, false, stream); } + +/** + * @brief Compute the weighted mean of the input matrix with a + * vector of weights, along rows or along columns + * + * @tparam value_t the data type + * @tparam idx_t Integer type used for addressing + * @tparam layout_t Layout type of the input matrix. + * @param[in] handle the raft handle + * @param[in] data the input matrix of size nrows * ncols + * @param[in] weights weight vector of size ncols if along_rows is true, else of size nrows + * @param[out] mu the output mean vector of size nrows if along_rows is true, else of size ncols + * @param[in] along_rows if true, compute one weighted mean per row; otherwise one per column + */ +template +void weighted_mean(const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view weights, + raft::device_vector_view mu, + bool along_rows) +{ + constexpr bool is_row_major = std::is_same_v; + constexpr bool is_col_major = std::is_same_v; + static_assert(is_row_major || is_col_major, + "weighted_mean: Layout must be either " + "raft::row_major or raft::col_major (or one of their aliases)"); + auto mean_vec_size = along_rows ? data.extent(0) : data.extent(1); + auto weight_size = along_rows ? 
data.extent(1) : data.extent(0); + + RAFT_EXPECTS(weights.extent(0) == weight_size, + "Size mismatch between weights and expected weight_size"); + RAFT_EXPECTS(mu.extent(0) == mean_vec_size, "Size mismatch betwen mu and expected mean_vec_size"); + + detail::weightedMean(mu.data_handle(), + data.data_handle(), + weights.data_handle(), + data.extent(1), + data.extent(0), + is_row_major, + along_rows, + handle.get_stream()); +} + +/** + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of column weights + * + * @tparam value_t the data type + * @tparam idx_t Integer type used for addressing + * @tparam layout_t Layout type of the input matrix. + * @param[in] handle the raft handle + * @param[in] data the input matrix of size nrows * ncols + * @param[in] weights weight vector of size ncols + * @param[out] mu the output mean vector of size nrows + */ +template +void row_weighted_mean(const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view weights, + raft::device_vector_view mu) +{ + weighted_mean(handle, data, weights, mu, true); +} + +/** + * @brief Compute the column-wise weighted mean of the input matrix with a + * vector of row weights + * + * @tparam value_t the data type + * @tparam idx_t Integer type used for addressing + * @tparam layout_t Layout type of the input matrix. 
+ * @param[in] handle the raft handle + * @param[in] data the input matrix of size nrows * ncols + * @param[in] weights weight vector of size nrows + * @param[out] mu the output mean vector of size ncols + */ +template +void col_weighted_mean(const raft::handle_t& handle, + raft::device_matrix_view data, + raft::device_vector_view weights, + raft::device_vector_view mu) +{ + weighted_mean(handle, data, weights, mu, false); +} }; // end namespace stats }; // end namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0c9b721294..a7f203ba6a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -237,6 +237,7 @@ if(BUILD_TESTS) ConfigureTest(NAME STATS_TEST PATH + test/stats/accuracy.cu test/stats/adjusted_rand_index.cu test/stats/completeness_score.cu test/stats/contingencyMatrix.cu @@ -252,7 +253,9 @@ if(BUILD_TESTS) test/stats/mean_center.cu test/stats/minmax.cu test/stats/mutual_info_score.cu + test/stats/r2_score.cu test/stats/rand_index.cu + test/stats/regression_metrics.cu test/stats/silhouette_score.cu test/stats/stddev.cu test/stats/sum.cu diff --git a/cpp/test/stats/accuracy.cu b/cpp/test/stats/accuracy.cu new file mode 100644 index 0000000000..192c187794 --- /dev/null +++ b/cpp/test/stats/accuracy.cu @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace stats { + +template +struct AccuracyInputs { + T tolerance; + int nrows; + unsigned long long int seed; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const AccuracyInputs& dims) +{ + return os; +} + +template +class AccuracyTest : public ::testing::TestWithParam> { + protected: + AccuracyTest() : stream(handle.get_stream()) {} + + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + raft::random::RngState r(params.seed); + rmm::device_uvector predictions(params.nrows, stream); + rmm::device_uvector ref_predictions(params.nrows, stream); + uniformInt(handle, r, predictions.data(), params.nrows, 0, 10); + uniformInt(handle, r, ref_predictions.data(), params.nrows, 0, 10); + + actualVal = + accuracy(handle, + raft::make_device_vector_view(predictions.data(), params.nrows), + raft::make_device_vector_view(ref_predictions.data(), params.nrows)); + expectedVal = T(0); + std::vector h_predictions(params.nrows, 0); + std::vector h_ref_predictions(params.nrows, 0); + raft::update_host(h_predictions.data(), predictions.data(), params.nrows, stream); + raft::update_host(h_ref_predictions.data(), ref_predictions.data(), params.nrows, stream); + + unsigned long long correctly_predicted = 0ULL; + for (int i = 0; i < params.nrows; ++i) { + correctly_predicted += (h_predictions[i] - h_ref_predictions[i]) == 0; + } + expectedVal = correctly_predicted * 1.0f / params.nrows; + raft::interruptible::synchronize(stream); + } + + protected: + AccuracyInputs params; + raft::handle_t handle; + cudaStream_t stream = 0; + T expectedVal, actualVal; +}; + +const std::vector> inputsf = { + {0.001f, 30, 1234ULL}, {0.001f, 100, 1234ULL}, {0.001f, 1000, 1234ULL}}; +typedef AccuracyTest AccuracyTestF; +TEST_P(AccuracyTestF, Result) +{ + auto eq = raft::CompareApprox(params.tolerance); + 
ASSERT_TRUE(match(expectedVal, actualVal, eq)); +} +INSTANTIATE_TEST_CASE_P(AccuracyTests, AccuracyTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 30, 1234ULL}, {0.001, 100, 1234ULL}, {0.001, 1000, 1234ULL}}; +typedef AccuracyTest AccuracyTestD; +TEST_P(AccuracyTestD, Result) +{ + auto eq = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(match(expectedVal, actualVal, eq)); +} +INSTANTIATE_TEST_CASE_P(AccuracyTests, AccuracyTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace stats +} // end namespace raft diff --git a/cpp/test/stats/r2_score.cu b/cpp/test/stats/r2_score.cu new file mode 100644 index 0000000000..d77daacb04 --- /dev/null +++ b/cpp/test/stats/r2_score.cu @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace stats { + +template +struct R2_scoreInputs { + T tolerance; + int nrows; + unsigned long long int seed; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const R2_scoreInputs& dims) +{ + return os; +} + +template +class R2_scoreTest : public ::testing::TestWithParam> { + protected: + R2_scoreTest() : stream(handle.get_stream()) {} + + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + raft::random::RngState r(params.seed); + rmm::device_uvector y(params.nrows, stream); + rmm::device_uvector y_hat(params.nrows, stream); + uniform(handle, r, y.data(), params.nrows, (T)-1.0, (T)1.0); + uniform(handle, r, y_hat.data(), params.nrows, (T)-1.0, (T)1.0); + + actualVal = r2_score(handle, + raft::make_device_vector_view(y.data(), params.nrows), + raft::make_device_vector_view(y_hat.data(), params.nrows)); + expectedVal = T(0); + std::vector h_y(params.nrows, 0); + std::vector h_y_hat(params.nrows, 0); + raft::update_host(h_y.data(), y.data(), params.nrows, stream); + raft::update_host(h_y_hat.data(), y_hat.data(), params.nrows, stream); + T mean = T(0); + for (int i = 0; i < params.nrows; ++i) { + mean += h_y[i]; + } + mean /= params.nrows; + + std::vector sse_arr(params.nrows, 0); + std::vector ssto_arr(params.nrows, 0); + T sse = T(0); + T ssto = T(0); + for (int i = 0; i < params.nrows; ++i) { + sse += (h_y[i] - h_y_hat[i]) * (h_y[i] - h_y_hat[i]); + ssto += (h_y[i] - mean) * (h_y[i] - mean); + } + expectedVal = 1.0 - sse / ssto; + raft::interruptible::synchronize(stream); + } + + protected: + R2_scoreInputs params; + raft::handle_t handle; + cudaStream_t stream = 0; + T expectedVal, actualVal; +}; + +const std::vector> inputsf = { + {0.001f, 30, 1234ULL}, {0.001f, 100, 1234ULL}, {0.001f, 1000, 1234ULL}}; +typedef R2_scoreTest R2_scoreTestF; +TEST_P(R2_scoreTestF, 
Result) +{ + auto eq = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(match(expectedVal, actualVal, eq)); +} +INSTANTIATE_TEST_CASE_P(R2_scoreTests, R2_scoreTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 30, 1234ULL}, {0.001, 100, 1234ULL}, {0.001, 1000, 1234ULL}}; +typedef R2_scoreTest R2_scoreTestD; +TEST_P(R2_scoreTestD, Result) +{ + auto eq = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(match(expectedVal, actualVal, eq)); +} +INSTANTIATE_TEST_CASE_P(R2_scoreTests, R2_scoreTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace stats +} // end namespace raft diff --git a/cpp/test/stats/regression_metrics.cu b/cpp/test/stats/regression_metrics.cu new file mode 100644 index 0000000000..effc3d04dd --- /dev/null +++ b/cpp/test/stats/regression_metrics.cu @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace stats { + +template +struct RegressionInputs { + T tolerance; + int len; + unsigned long long int seed; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const RegressionInputs& dims) +{ + return os; +} + +template +void naive_reg_metrics(std::vector& predictions, + std::vector& ref_predictions, + double& mean_abs_error, + double& mean_squared_error, + double& median_abs_error) +{ + auto len = predictions.size(); + double abs_diff = 0; + double sq_diff = 0; + std::vector abs_errors(len); + for (std::size_t i = 0; i < len; ++i) { + auto diff = predictions[i] - ref_predictions[i]; + abs_diff += abs(diff); + sq_diff += diff * diff; + abs_errors[i] = abs(diff); + } + mean_abs_error = abs_diff / len; + mean_squared_error = sq_diff / len; + + std::sort(abs_errors.begin(), abs_errors.end()); + auto middle = len / 2; + if (len % 2 == 1) { + median_abs_error = abs_errors[middle]; + } else { + median_abs_error = (abs_errors[middle] + abs_errors[middle - 1]) / 2; + } +} + +template +class RegressionTest : public ::testing::TestWithParam> { + protected: + RegressionTest() : stream(handle.get_stream()) {} + + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + raft::random::RngState r(params.seed); + rmm::device_uvector predictions(params.len, stream); + rmm::device_uvector ref_predictions(params.len, stream); + uniform(handle, r, predictions.data(), params.len, T(-10.0), T(10.0)); + uniform(handle, r, ref_predictions.data(), params.len, T(-10.0), T(10.0)); + + regression_metrics(handle, + raft::make_device_vector_view(predictions.data(), params.len), + raft::make_device_vector_view(ref_predictions.data(), params.len), + raft::make_host_scalar_view(&mean_abs_error), + raft::make_host_scalar_view(&mean_squared_error), + 
raft::make_host_scalar_view(&median_abs_error)); + std::vector h_predictions(params.len, 0); + std::vector h_ref_predictions(params.len, 0); + raft::update_host(h_predictions.data(), predictions.data(), params.len, stream); + raft::update_host(h_ref_predictions.data(), ref_predictions.data(), params.len, stream); + + naive_reg_metrics(h_predictions, + h_ref_predictions, + ref_mean_abs_error, + ref_mean_squared_error, + ref_median_abs_error); + raft::interruptible::synchronize(stream); + } + + protected: + RegressionInputs params; + raft::handle_t handle; + cudaStream_t stream = 0; + double mean_abs_error = 0; + double mean_squared_error = 0; + double median_abs_error = 0; + double ref_mean_abs_error = 0; + double ref_mean_squared_error = 0; + double ref_median_abs_error = 0; +}; + +const std::vector> inputsf = { + {0.001f, 30, 1234ULL}, {0.001f, 100, 1234ULL}, {0.001f, 4000, 1234ULL}}; +typedef RegressionTest RegressionTestF; +TEST_P(RegressionTestF, Result) +{ + auto eq = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(match(ref_mean_abs_error, mean_abs_error, eq)); + ASSERT_TRUE(match(ref_mean_squared_error, mean_squared_error, eq)); + ASSERT_TRUE(match(ref_median_abs_error, median_abs_error, eq)); +} +INSTANTIATE_TEST_CASE_P(RegressionTests, RegressionTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 30, 1234ULL}, {0.001, 100, 1234ULL}, {0.001, 4000, 1234ULL}}; +typedef RegressionTest RegressionTestD; +TEST_P(RegressionTestD, Result) +{ + auto eq = raft::CompareApprox(params.tolerance); + ASSERT_TRUE(match(ref_mean_abs_error, mean_abs_error, eq)); + ASSERT_TRUE(match(ref_mean_squared_error, mean_squared_error, eq)); + ASSERT_TRUE(match(ref_median_abs_error, median_abs_error, eq)); +} +INSTANTIATE_TEST_CASE_P(RegressionTests, RegressionTestD, ::testing::ValuesIn(inputsd)); + +} // end namespace stats +} // end namespace raft diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index 
ec99d5a627..9f33855572 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -15,7 +15,9 @@ */ #include "../test_utils.h" +#include #include +#include #include #include #include @@ -87,11 +89,23 @@ class RowWeightedMeanTest : public ::testing::TestWithParam hexp(rows); // compute naive result & copy to GPU - naiveRowWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, true); - dexp = hexp; - - // compute result - rowWeightedMean(dact.data().get(), din.data().get(), dweights.data().get(), cols, rows, stream); + naiveRowWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major); + dexp = hexp; + auto output = raft::make_device_vector_view(dact.data().get(), rows); + auto weights = + raft::make_device_vector_view(dweights.data().get(), cols); + + if (params.row_major) { + auto input = raft::make_device_matrix_view( + din.data().get(), rows, cols); + // compute result + row_weighted_mean(handle, input, weights, output); + } else { + auto input = raft::make_device_matrix_view( + din.data().get(), rows, cols); + // compute result + row_weighted_mean(handle, input, weights, output); + } // adjust tolerance to account for round-off accumulation params.tolerance *= params.N; @@ -150,12 +164,23 @@ class ColWeightedMeanTest : public ::testing::TestWithParam hexp(cols); // compute naive result & copy to GPU - naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, true); + naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major); dexp = hexp; - // compute result - colWeightedMean(dact.data().get(), din.data().get(), dweights.data().get(), cols, rows, stream); - + auto output = raft::make_device_vector_view(dact.data().get(), cols); + auto weights = + raft::make_device_vector_view(dweights.data().get(), rows); + if (params.row_major) { + auto input = raft::make_device_matrix_view( + din.data().get(), rows, cols); + // compute result + 
col_weighted_mean(handle, input, weights, output); + } else { + auto input = raft::make_device_matrix_view( + din.data().get(), rows, cols); + // compute result + col_weighted_mean(handle, input, weights, output); + } // adjust tolerance to account for round-off accumulation params.tolerance *= params.M; } @@ -200,16 +225,20 @@ class WeightedMeanTest : public ::testing::TestWithParam> naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major); dexp = hexp; - // compute result - weightedMean(dact.data().get(), - din.data().get(), - dweights.data().get(), - cols, - rows, - params.row_major, - params.along_rows, - stream); - + auto output = raft::make_device_vector_view(dact.data().get(), mean_size); + auto weights = + raft::make_device_vector_view(dweights.data().get(), weight_size); + if (params.row_major) { + auto input = raft::make_device_matrix_view( + din.data().get(), rows, cols); + // compute result + weighted_mean(handle, input, weights, output, params.along_rows); + } else { + auto input = raft::make_device_matrix_view( + din.data().get(), rows, cols); + // compute result + weighted_mean(handle, input, weights, output, params.along_rows); + } // adjust tolerance to account for round-off accumulation params.tolerance *= params.N; } @@ -226,6 +255,10 @@ static const float tolF = 128 * std::numeric_limits::epsilon(); static const double tolD = 256 * std::numeric_limits::epsilon(); const std::vector> inputsf = {{tolF, 4, 4, 1234, true, true}, + {tolF, 32, 32, 1234, true, false}, + {tolF, 32, 64, 1234, false, false}, + {tolF, 32, 256, 1234, true, true}, + {tolF, 32, 256, 1234, false, false}, {tolF, 1024, 32, 1234, true, false}, {tolF, 1024, 64, 1234, true, true}, {tolF, 1024, 128, 1234, true, false}, @@ -236,6 +269,10 @@ const std::vector> inputsf = {{tolF, 4, 4, 1234, true, {tolF, 1024, 256, 1234, false, true}}; const std::vector> inputsd = {{tolD, 4, 4, 1234, true, true}, + {tolD, 32, 32, 1234, true, false}, + {tolD, 32, 64, 
1234, false, false}, + {tolD, 32, 256, 1234, true, true}, + {tolD, 32, 256, 1234, false, false}, {tolD, 1024, 32, 1234, true, false}, {tolD, 1024, 64, 1234, true, true}, {tolD, 1024, 128, 1234, true, false}, @@ -280,16 +317,20 @@ INSTANTIATE_TEST_CASE_P(ColWeightedMeanTest, ColWeightedMeanTestD, ::testing::Va using WeightedMeanTestF = WeightedMeanTest; TEST_P(WeightedMeanTestF, Result) { + auto mean_size = params.along_rows ? params.M : params.N; ASSERT_TRUE(devArrMatch( - dexp.data().get(), dact.data().get(), params.N, raft::CompareApprox(params.tolerance))); + dexp.data().get(), dact.data().get(), mean_size, raft::CompareApprox(params.tolerance))); } INSTANTIATE_TEST_CASE_P(WeightedMeanTest, WeightedMeanTestF, ::testing::ValuesIn(inputsf)); using WeightedMeanTestD = WeightedMeanTest; TEST_P(WeightedMeanTestD, Result) { - ASSERT_TRUE(devArrMatch( - dexp.data().get(), dact.data().get(), params.N, raft::CompareApprox(params.tolerance))); + auto mean_size = params.along_rows ? params.M : params.N; + ASSERT_TRUE(devArrMatch(dexp.data().get(), + dact.data().get(), + mean_size, + raft::CompareApprox(params.tolerance))); } INSTANTIATE_TEST_CASE_P(WeightedMeanTest, WeightedMeanTestD, ::testing::ValuesIn(inputsd));