Skip to content

Commit

Permalink
mdspanifying weighted_mean and add raft::stats tests (#910)
Browse files Browse the repository at this point in the history
Closes #880

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #910
  • Loading branch information
lowener authored Oct 13, 2022
1 parent f56daa6 commit ea9a50b
Show file tree
Hide file tree
Showing 6 changed files with 514 additions and 22 deletions.
85 changes: 85 additions & 0 deletions cpp/include/raft/stats/weighted_mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/stats/detail/weighted_mean.cuh>

namespace raft {
Expand Down Expand Up @@ -91,6 +92,90 @@ void colWeightedMean(
{
weightedMean(mu, data, weights, D, N, true, false, stream);
}

/**
 * @brief Compute the weighted mean of the input matrix with a
 * vector of weights, along rows or along columns
 *
 * The layout is validated at compile time (only raft::row_major and
 * raft::col_major are accepted) and the sizes of `weights` and `mu` are
 * validated with RAFT_EXPECTS before the computation is forwarded to
 * detail::weightedMean on the stream returned by `handle.get_stream()`.
 *
 * @tparam value_t the data type
 * @tparam idx_t Integer type used for addressing
 * @tparam layout_t Layout type of the input matrix.
 * @param[in]  handle the raft handle
 * @param[in]  data the input matrix of size nrows * ncols
 * @param[in]  weights weight vector of size ncols if along_rows is true, else of size nrows
 * @param[out] mu the output mean vector of size nrows if along_rows is true, else of size ncols
 * @param[in]  along_rows whether to reduce along rows or columns
 */
template <typename value_t, typename idx_t, typename layout_t>
void weighted_mean(const raft::handle_t& handle,
                   raft::device_matrix_view<const value_t, idx_t, layout_t> data,
                   raft::device_vector_view<const value_t, idx_t> weights,
                   raft::device_vector_view<value_t, idx_t> mu,
                   bool along_rows)
{
  constexpr bool is_row_major = std::is_same_v<layout_t, raft::row_major>;
  constexpr bool is_col_major = std::is_same_v<layout_t, raft::col_major>;
  static_assert(is_row_major || is_col_major,
                "weighted_mean: Layout must be either "
                "raft::row_major or raft::col_major (or one of their aliases)");

  // Reducing along rows yields one mean per row and consumes one weight per
  // column; reducing along columns is the transpose of that.
  auto mean_vec_size = along_rows ? data.extent(0) : data.extent(1);
  auto weight_size   = along_rows ? data.extent(1) : data.extent(0);

  RAFT_EXPECTS(weights.extent(0) == weight_size,
               "Size mismatch between weights and expected weight_size");
  RAFT_EXPECTS(mu.extent(0) == mean_vec_size,
               "Size mismatch between mu and expected mean_vec_size");

  // Note the argument order: the column count (extent(1)) is passed before the
  // row count (extent(0)).
  detail::weightedMean(mu.data_handle(),
                       data.data_handle(),
                       weights.data_handle(),
                       data.extent(1),
                       data.extent(0),
                       is_row_major,
                       along_rows,
                       handle.get_stream());
}

/**
 * @brief Row-wise weighted mean of the input matrix, using one weight per
 * column. Convenience wrapper around weighted_mean() with the reduction
 * direction fixed to "along rows".
 *
 * @tparam value_t the data type
 * @tparam idx_t Integer type used for addressing
 * @tparam layout_t Layout type of the input matrix.
 * @param[in]  handle the raft handle
 * @param[in]  data the input matrix of size nrows * ncols
 * @param[in]  weights weight vector of size ncols
 * @param[out] mu the output mean vector of size nrows
 */
template <typename value_t, typename idx_t, typename layout_t>
void row_weighted_mean(const raft::handle_t& handle,
                       raft::device_matrix_view<const value_t, idx_t, layout_t> data,
                       raft::device_vector_view<const value_t, idx_t> weights,
                       raft::device_vector_view<value_t, idx_t> mu)
{
  constexpr bool reduce_along_rows = true;
  weighted_mean(handle, data, weights, mu, reduce_along_rows);
}

/**
 * @brief Column-wise weighted mean of the input matrix, using one weight per
 * row. Convenience wrapper around weighted_mean() with the reduction
 * direction fixed to "along columns".
 *
 * @tparam value_t the data type
 * @tparam idx_t Integer type used for addressing
 * @tparam layout_t Layout type of the input matrix.
 * @param[in]  handle the raft handle
 * @param[in]  data the input matrix of size nrows * ncols
 * @param[in]  weights weight vector of size nrows
 * @param[out] mu the output mean vector of size ncols
 */
template <typename value_t, typename idx_t, typename layout_t>
void col_weighted_mean(const raft::handle_t& handle,
                       raft::device_matrix_view<const value_t, idx_t, layout_t> data,
                       raft::device_vector_view<const value_t, idx_t> weights,
                       raft::device_vector_view<value_t, idx_t> mu)
{
  constexpr bool reduce_along_rows = false;
  weighted_mean(handle, data, weights, mu, reduce_along_rows);
}
}; // end namespace stats
}; // end namespace raft

Expand Down
3 changes: 3 additions & 0 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ if(BUILD_TESTS)

ConfigureTest(NAME STATS_TEST
PATH
test/stats/accuracy.cu
test/stats/adjusted_rand_index.cu
test/stats/completeness_score.cu
test/stats/contingencyMatrix.cu
Expand All @@ -252,7 +253,9 @@ if(BUILD_TESTS)
test/stats/mean_center.cu
test/stats/minmax.cu
test/stats/mutual_info_score.cu
test/stats/r2_score.cu
test/stats/rand_index.cu
test/stats/regression_metrics.cu
test/stats/silhouette_score.cu
test/stats/stddev.cu
test/stats/sum.cu
Expand Down
105 changes: 105 additions & 0 deletions cpp/test/stats/accuracy.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <optional>
#include <raft/interruptible.hpp>
#include <raft/random/rng.cuh>
#include <raft/stats/accuracy.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_uvector.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <vector>

namespace raft {
namespace stats {

/**
 * @brief Parameters of one accuracy test case.
 *
 * Plain aggregate; the test tables initialise it positionally as
 * {tolerance, nrows, seed}, so the member order must not change.
 *
 * @tparam T floating-point type of the compared accuracy values
 */
template <typename T>
struct AccuracyInputs {
  /** Tolerance handed to the approximate comparator in the test body. */
  T tolerance;
  /** Number of elements in each of the two prediction vectors. */
  int nrows;
  /** Seed used to construct the random number generator state. */
  unsigned long long seed;
};

template <typename T>
::std::ostream& operator<<(::std::ostream& os, const AccuracyInputs<T>& dims)
{
return os;
}

/**
 * @brief Value-parameterized fixture that checks accuracy() against a
 * reference computed on the host.
 *
 * SetUp() fills two device vectors of `params.nrows` integers via uniformInt
 * (called with bounds 0 and 10), stores the result of accuracy() in
 * `actualVal`, then copies both vectors to the host and stores the fraction
 * of matching positions in `expectedVal`.
 *
 * @tparam T floating-point type of the expected/actual accuracy values
 */
template <typename T>
class AccuracyTest : public ::testing::TestWithParam<AccuracyInputs<T>> {
 protected:
  // `handle` is declared before `stream` below, so it is already constructed
  // when the stream is taken from it here.
  AccuracyTest() : stream(handle.get_stream()) {}

  void SetUp() override
  {
    params = ::testing::TestWithParam<AccuracyInputs<T>>::GetParam();
    raft::random::RngState r(params.seed);
    rmm::device_uvector<int> predictions(params.nrows, stream);
    rmm::device_uvector<int> ref_predictions(params.nrows, stream);
    uniformInt(handle, r, predictions.data(), params.nrows, 0, 10);
    uniformInt(handle, r, ref_predictions.data(), params.nrows, 0, 10);

    // Value under test.
    actualVal =
      accuracy(handle,
               raft::make_device_vector_view<const int>(predictions.data(), params.nrows),
               raft::make_device_vector_view<const int>(ref_predictions.data(), params.nrows));

    // Host reference: copy both vectors back, then count equal positions.
    std::vector<int> h_predictions(params.nrows, 0);
    std::vector<int> h_ref_predictions(params.nrows, 0);
    raft::update_host(h_predictions.data(), predictions.data(), params.nrows, stream);
    raft::update_host(h_ref_predictions.data(), ref_predictions.data(), params.nrows, stream);
    // The copies above are issued on `stream`; wait for them to finish before
    // the host buffers are read.
    raft::interruptible::synchronize(stream);

    unsigned long long correctly_predicted = 0ULL;
    for (int i = 0; i < params.nrows; ++i) {
      correctly_predicted += (h_predictions[i] == h_ref_predictions[i]);
    }
    // Divide in T so the double instantiation is not limited to float precision.
    expectedVal = static_cast<T>(correctly_predicted) / static_cast<T>(params.nrows);
  }

 protected:
  AccuracyInputs<T> params;
  raft::handle_t handle;
  cudaStream_t stream = 0;
  T expectedVal, actualVal;
};

// Single-precision cases, each given as {tolerance, nrows, seed}.
const std::vector<AccuracyInputs<float>> inputsf = {
  {0.001f, 30, 1234ULL},
  {0.001f, 100, 1234ULL},
  {0.001f, 1000, 1234ULL},
};
using AccuracyTestF = AccuracyTest<float>;
// The fixture computes both values in SetUp(); only the comparison happens here.
TEST_P(AccuracyTestF, Result)
{
  auto approx_eq = raft::CompareApprox<float>(params.tolerance);
  ASSERT_TRUE(match(expectedVal, actualVal, approx_eq));
}
INSTANTIATE_TEST_CASE_P(AccuracyTests, AccuracyTestF, ::testing::ValuesIn(inputsf));

// Double-precision cases, each given as {tolerance, nrows, seed}.
const std::vector<AccuracyInputs<double>> inputsd = {
  {0.001, 30, 1234ULL},
  {0.001, 100, 1234ULL},
  {0.001, 1000, 1234ULL},
};
using AccuracyTestD = AccuracyTest<double>;
// The fixture computes both values in SetUp(); only the comparison happens here.
TEST_P(AccuracyTestD, Result)
{
  auto approx_eq = raft::CompareApprox<double>(params.tolerance);
  ASSERT_TRUE(match(expectedVal, actualVal, approx_eq));
}
INSTANTIATE_TEST_CASE_P(AccuracyTests, AccuracyTestD, ::testing::ValuesIn(inputsd));

} // end namespace stats
} // end namespace raft
113 changes: 113 additions & 0 deletions cpp/test/stats/r2_score.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <optional>
#include <raft/interruptible.hpp>
#include <raft/random/rng.cuh>
#include <raft/stats/r2_score.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_uvector.hpp>
#include <stdio.h>
#include <stdlib.h>
#include <vector>

namespace raft {
namespace stats {

/**
 * @brief Parameters of one r2_score test case.
 *
 * Plain aggregate; the test tables initialise it positionally as
 * {tolerance, nrows, seed}, so the member order must not change.
 *
 * @tparam T floating-point type of the data and of the compared scores
 */
template <typename T>
struct R2_scoreInputs {
  /** Tolerance handed to the approximate comparator in the test body. */
  T tolerance;
  /** Number of elements in each of the two input vectors. */
  int nrows;
  /** Seed used to construct the random number generator state. */
  unsigned long long seed;
};

template <typename T>
::std::ostream& operator<<(::std::ostream& os, const R2_scoreInputs<T>& dims)
{
return os;
}

/**
 * @brief Value-parameterized fixture that checks r2_score() against a
 * reference computed on the host.
 *
 * SetUp() fills two device vectors (`y`, `y_hat`) of `params.nrows` values
 * via uniform (called with bounds -1 and 1), stores the result of r2_score()
 * in `actualVal`, then copies both vectors to the host and stores
 * 1 - SSE / SSTO in `expectedVal`, where SSE is the sum of squared
 * differences between `y` and `y_hat` and SSTO is the sum of squared
 * deviations of `y` from its mean.
 *
 * @tparam T floating-point type of the data and of the compared scores
 */
template <typename T>
class R2_scoreTest : public ::testing::TestWithParam<R2_scoreInputs<T>> {
 protected:
  // `handle` is declared before `stream` below, so it is already constructed
  // when the stream is taken from it here.
  R2_scoreTest() : stream(handle.get_stream()) {}

  void SetUp() override
  {
    params = ::testing::TestWithParam<R2_scoreInputs<T>>::GetParam();
    raft::random::RngState r(params.seed);
    rmm::device_uvector<T> y(params.nrows, stream);
    rmm::device_uvector<T> y_hat(params.nrows, stream);
    uniform(handle, r, y.data(), params.nrows, (T)-1.0, (T)1.0);
    uniform(handle, r, y_hat.data(), params.nrows, (T)-1.0, (T)1.0);

    // Value under test.
    actualVal = r2_score(handle,
                         raft::make_device_vector_view<const T>(y.data(), params.nrows),
                         raft::make_device_vector_view<const T>(y_hat.data(), params.nrows));

    // Host reference: copy both vectors back before touching them on the host.
    std::vector<T> h_y(params.nrows, 0);
    std::vector<T> h_y_hat(params.nrows, 0);
    raft::update_host(h_y.data(), y.data(), params.nrows, stream);
    raft::update_host(h_y_hat.data(), y_hat.data(), params.nrows, stream);
    // The copies above are issued on `stream`; wait for them to finish before
    // the host buffers are read.
    raft::interruptible::synchronize(stream);

    T mean = T(0);
    for (int i = 0; i < params.nrows; ++i) {
      mean += h_y[i];
    }
    mean /= params.nrows;

    T sse  = T(0);  // sum of squared errors between y and y_hat
    T ssto = T(0);  // total sum of squares of y around its mean
    for (int i = 0; i < params.nrows; ++i) {
      const T err = h_y[i] - h_y_hat[i];
      const T dev = h_y[i] - mean;
      sse += err * err;
      ssto += dev * dev;
    }
    expectedVal = T(1) - sse / ssto;
  }

 protected:
  R2_scoreInputs<T> params;
  raft::handle_t handle;
  cudaStream_t stream = 0;
  T expectedVal, actualVal;
};

// Single-precision cases, each given as {tolerance, nrows, seed}.
const std::vector<R2_scoreInputs<float>> inputsf = {
  {0.001f, 30, 1234ULL},
  {0.001f, 100, 1234ULL},
  {0.001f, 1000, 1234ULL},
};
using R2_scoreTestF = R2_scoreTest<float>;
// The fixture computes both values in SetUp(); only the comparison happens here.
TEST_P(R2_scoreTestF, Result)
{
  auto approx_eq = raft::CompareApprox<float>(params.tolerance);
  ASSERT_TRUE(match(expectedVal, actualVal, approx_eq));
}
INSTANTIATE_TEST_CASE_P(R2_scoreTests, R2_scoreTestF, ::testing::ValuesIn(inputsf));

// Double-precision cases, each given as {tolerance, nrows, seed}.
const std::vector<R2_scoreInputs<double>> inputsd = {
  {0.001, 30, 1234ULL},
  {0.001, 100, 1234ULL},
  {0.001, 1000, 1234ULL},
};
using R2_scoreTestD = R2_scoreTest<double>;
// The fixture computes both values in SetUp(); only the comparison happens here.
TEST_P(R2_scoreTestD, Result)
{
  auto approx_eq = raft::CompareApprox<double>(params.tolerance);
  ASSERT_TRUE(match(expectedVal, actualVal, approx_eq));
}
INSTANTIATE_TEST_CASE_P(R2_scoreTests, R2_scoreTestD, ::testing::ValuesIn(inputsd));

} // end namespace stats
} // end namespace raft
Loading

0 comments on commit ea9a50b

Please sign in to comment.