From 70beb2f57927408223e475f4fc33bd091b715dc0 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 16 Feb 2022 17:39:52 +0100 Subject: [PATCH 1/6] Add rowSampleWeightedMean --- .../raft/stats/detail/weighted_mean.cuh | 52 ++++++++++++++++++- cpp/include/raft/stats/weighted_mean.hpp | 33 +++++++++++- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index ca7fc136d3..ef16e417b8 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -18,14 +18,17 @@ #include #include +#include #include +#include namespace raft { namespace stats { namespace detail { /** - * @brief Compute the row-wise weighted mean of the input matrix + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of column weights * * @tparam Type the data type * @param mu the output mean vector @@ -58,7 +61,52 @@ void rowWeightedMean( } /** - * @brief Compute the column-wise weighted mean of the input matrix + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of sample weights + * + * @tparam Type the data type + * @param mu the output mean vector + * @param data the input matrix + * @param weights per-sample weight + * @param D number of columns of data + * @param N number of rows of data + * @param row_major input matrix is row-major or not + * @param along_rows whether to reduce along rows or columns + * @param stream cuda stream to launch work on + */ +template +void rowSampleWeightedMean(Type* mu, + const Type* data, + const Type* weights, + int D, + int N, + bool row_major, + bool along_rows, + cudaStream_t stream) +{ + // sum the weights & copy back to CPU + Type WS = 0; + raft::stats::sum(mu, weights, 1, N, row_major, stream); + raft::update_host(&WS, mu, 1, stream); + + raft::linalg::reduce( + mu, + data, + D, + N, + (Type)0, + row_major, + along_rows, + stream, + false, + [weights] __device__(Type v, int i) { return v * weights[i]; }, + [] __device__(Type a, Type b) { return a + b; }, + [WS] __device__(Type v) { return v / WS; }); +} + +/** + * @brief Compute the column-wise weighted mean of the input matrix with a + * vector of column weights * * @tparam Type the data type * @param mu the output mean vector diff --git a/cpp/include/raft/stats/weighted_mean.hpp b/cpp/include/raft/stats/weighted_mean.hpp index ad90142a08..fc285054ac 100644 --- a/cpp/include/raft/stats/weighted_mean.hpp +++ b/cpp/include/raft/stats/weighted_mean.hpp @@ -22,7 +22,8 @@ namespace raft { namespace stats { /** - * @brief Compute the row-wise weighted mean of the input matrix + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of column weights * * @tparam Type the data type * @param mu the output mean vector @@ -40,7 +41,35 @@ void rowWeightedMean( } /** - * @brief Compute the column-wise weighted mean of the input matrix + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of sample weights + * + * @tparam Type the data type + * @param mu the output mean vector + * @param data the input matrix + * @param weights per-sample weight + * @param D number of columns of data + * @param N number of rows of data + * @param row_major input matrix is row-major or not + * @param along_rows whether to reduce along rows or columns + * @param stream cuda stream to launch work on + */ +template +void rowSampleWeightedMean(Type* mu, + const Type* data, + const Type* weights, + int D, + int N, + bool row_major, + bool along_rows, + cudaStream_t stream) +{ + detail::rowSampleWeightedMean(mu, data, weights, D, N, row_major, along_rows, stream); +} + +/** + * @brief Compute the column-wise weighted mean of the input matrix with a + * vector of column weights * * @tparam Type the data type * @param mu the output mean vector From df09d47fbf94ade238dd7755cbb59eba07f6a11e Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 23 Feb 2022 20:00:34 +0100 Subject: [PATCH 2/6] Unify weighted mean code --- .../raft/stats/detail/weighted_mean.cuh | 102 +++------------- cpp/include/raft/stats/weighted_mean.hpp | 61 +++++----- cpp/test/stats/weighted_mean.cu | 111 ++++++++++++++---- 3 files changed, 140 insertions(+), 134 deletions(-) diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index ef16e417b8..d71b3744ae 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -18,9 +18,9 @@ #include #include -#include +#include #include -#include +#include namespace raft { namespace stats { @@ -28,65 +28,33 @@ namespace detail { /** * @brief Compute the row-wise weighted mean of the input matrix with a - * vector of column weights - * - * @tparam Type the data type - * @param mu the output mean vector - * @param data the input matrix (assumed to be row-major) - * @param weights per-column means - * @param D number of columns of data - * @param N number of rows of data - * @param stream cuda stream to launch work on - */ -template -void rowWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream) -{ - // sum the weights & copy back to CPU - Type WS = 0; - raft::linalg::coalescedReduction(mu, weights, D, 1, (Type)0, stream, false); - raft::update_host(&WS, mu, 1, stream); - - raft::linalg::coalescedReduction( - mu, - data, - D, - N, - (Type)0, - stream, - false, - [weights] __device__(Type v, int i) { return v * weights[i]; }, - [] __device__(Type a, Type b) { return a + b; }, - [WS] __device__(Type v) { return v / WS; }); -} - -/** - * @brief Compute the row-wise weighted mean of the input matrix with a - * vector of sample weights + * vector of weights * * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing * @param mu the output mean vector * @param data the input matrix - * @param weights per-sample weight + * @param weights weight of size D if along_row is true, else of size N * @param D number of columns of data * @param N number of rows of data - * @param row_major input matrix is row-major or not + * @param row_major data input matrix is row-major or not * @param along_rows whether to reduce along rows or columns * @param stream cuda stream to launch work on */ -template -void rowSampleWeightedMean(Type* mu, - const Type* data, - const Type* weights, - int D, - int N, - bool row_major, - bool along_rows, - cudaStream_t stream) +template +void weightedMean(Type* mu, + const Type* data, + const Type* weights, + IdxType D, + IdxType N, + bool row_major, + bool along_rows, + cudaStream_t stream) { // sum the weights & copy back to CPU + auto weight_size = along_rows ? D : N; Type WS = 0; - raft::stats::sum(mu, weights, 1, N, row_major, stream); + raft::stats::sum(mu, weights, (IdxType)1, weight_size, false, stream); raft::update_host(&WS, mu, 1, stream); raft::linalg::reduce( @@ -99,41 +67,7 @@ void rowSampleWeightedMean(Type* mu, along_rows, stream, false, - [weights] __device__(Type v, int i) { return v * weights[i]; }, - [] __device__(Type a, Type b) { return a + b; }, - [WS] __device__(Type v) { return v / WS; }); -} - -/** - * @brief Compute the column-wise weighted mean of the input matrix with a - * vector of column weights - * - * @tparam Type the data type - * @param mu the output mean vector - * @param data the input matrix (assumed to be column-major) - * @param weights per-column means - * @param D number of columns of data - * @param N number of rows of data - * @param stream cuda stream to launch work on - */ -template -void colWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream) -{ - // sum the weights & copy back to CPU - Type WS = 0; - raft::linalg::stridedReduction(mu, weights, 1, N, (Type)0, stream, false); - raft::update_host(&WS, mu, 1, stream); - - raft::linalg::stridedReduction( - mu, - data, - D, - N, - (Type)0, - stream, - false, - [weights] __device__(Type v, int i) { return v * weights[i]; }, + [weights] __device__(Type v, IdxType i) { return v * weights[i]; }, [] __device__(Type a, Type b) { return a + b; }, [WS] __device__(Type v) { return v / WS; }); } diff --git a/cpp/include/raft/stats/weighted_mean.hpp b/cpp/include/raft/stats/weighted_mean.hpp index fc285054ac..1215119f4b 100644 --- a/cpp/include/raft/stats/weighted_mean.hpp +++ b/cpp/include/raft/stats/weighted_mean.hpp @@ -22,68 +22,71 @@ namespace raft { namespace stats { /** - * @brief Compute the row-wise weighted mean of the input matrix with a - * vector of column weights + * @brief Compute the weighted mean of the input matrix with a + * vector of weights, along rows or along columns * * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing * @param mu the output mean vector - * @param data the input matrix (assumed to be row-major) - * @param weights per-column means + * @param data the input matrix + * @param weights weight of size D if along_row is true, else of size N * @param D number of columns of data * @param N number of rows of data + * @param row_major data input matrix is row-major or not + * @param along_rows whether to reduce along rows or columns * @param stream cuda stream to launch work on */ -template -void rowWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream) +template +void weightedMean(Type* mu, + const Type* data, + const Type* weights, + IdxType D, + IdxType N, + bool row_major, + bool along_rows, + cudaStream_t stream) { - detail::rowWeightedMean(mu, data, weights, D, N, stream); + detail::weightedMean(mu, data, weights, D, N, row_major, along_rows, stream); } /** * @brief Compute the row-wise weighted mean of the input matrix with a - * vector of sample weights + * vector of column weights * * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing * @param mu the output mean vector - * @param data the input matrix - * @param weights per-sample weight + * @param data the input matrix (assumed to be row-major) + * @param weights per-column means * @param D number of columns of data * @param N number of rows of data - * @param row_major input matrix is row-major or not - * @param along_rows whether to reduce along rows or columns * @param stream cuda stream to launch work on */ -template -void rowSampleWeightedMean(Type* mu, - const Type* data, - const Type* weights, - int D, - int N, - bool row_major, - bool along_rows, - cudaStream_t stream) +template +void rowWeightedMean( + Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream) { - detail::rowSampleWeightedMean(mu, data, weights, D, N, row_major, along_rows, stream); + weightedMean(mu, data, weights, D, N, true, true, stream); } /** * @brief Compute the column-wise weighted mean of the input matrix with a - * vector of column weights + * vector of row weights * * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing * @param mu the output mean vector - * @param data the input matrix (assumed to be column-major) - * @param weights per-column means + * @param data the input matrix (assumed to be row-major) + * @param weights per-row means * @param D number of columns of data * @param N number of rows of data * @param stream cuda stream to launch work on */ -template +template void colWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream) + Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream) { - detail::colWeightedMean(mu, data, weights, D, N, stream); + weightedMean(mu, data, weights, D, N, true, false, stream); } }; // end namespace stats }; // end namespace raft diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index ee58747b69..1952c4ac71 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -30,13 +30,14 @@ struct WeightedMeanInputs { T tolerance; int M, N; unsigned long long int seed; + bool along_rows; // Used only for the weightedMean function }; template ::std::ostream& operator<<(::std::ostream& os, const WeightedMeanInputs& I) { - return os << "{ " << I.tolerance << ", " << I.M << ", " << I.N << ", " << I.seed << "}" - << std::endl; + return os << "{ " << I.tolerance << ", " << I.M << ", " << I.N << ", " << I.seed + << ", " << I.along_rows << "}" << std::endl; } ///// weighted row-wise mean test and support functions @@ -171,29 +172,81 @@ class ColWeightedMeanTest : public ::testing::TestWithParam din, dweights, dexp, dact; }; +template +class WeightedMeanTest : public ::testing::TestWithParam> { + protected: + void SetUp() override + { + params = ::testing::TestWithParam>::GetParam(); + raft::random::Rng r(params.seed); + int rows = params.M, cols = params.N, len = rows * cols; + auto weight_size = params.along_rows ? cols : rows; + auto mean_size = params.along_rows ? rows : cols; + cudaStream_t stream = 0; + RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + // device-side data + din.resize(len); + dweights.resize(weight_size); + dexp.resize(mean_size); + dact.resize(mean_size); + + // create random matrix and weights + r.uniform(din.data().get(), len, T(-1.0), T(1.0), stream); + r.uniform(dweights.data().get(), weight_size, T(-1.0), T(1.0), stream); + + // host-side data + thrust::host_vector hin = din; + thrust::host_vector hweights = dweights; + thrust::host_vector hexp(mean_size); + + // compute naive result & copy to GPU + if (params.along_rows) + naiveRowWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, false); + else + naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, false); + dexp = hexp; + + // compute ml-prims result + weightedMean(dact.data().get(), din.data().get(), dweights.data().get(), cols, rows, + false, params.along_rows, stream); + + // adjust tolerance to account for round-off accumulation + params.tolerance *= params.N; + RAFT_CUDA_TRY(cudaStreamDestroy(stream)); + } + + void TearDown() override {} + + protected: + WeightedMeanInputs params; + thrust::host_vector hin, hweights; + thrust::device_vector din, dweights, dexp, dact; +}; + + ////// Parameter sets and test instantiation static const float tolF = 128 * std::numeric_limits::epsilon(); static const double tolD = 256 * std::numeric_limits::epsilon(); -const std::vector> inputsf = {{tolF, 4, 4, 1234}, - {tolF, 1024, 32, 1234}, - {tolF, 1024, 64, 1234}, - {tolF, 1024, 128, 1234}, - {tolF, 1024, 256, 1234}, - {tolF, 1024, 32, 1234}, - {tolF, 1024, 64, 1234}, - {tolF, 1024, 128, 1234}, - {tolF, 1024, 256, 1234}}; - -const std::vector> inputsd = {{tolD, 4, 4, 1234}, - {tolD, 1024, 32, 1234}, - {tolD, 1024, 64, 1234}, - {tolD, 1024, 128, 1234}, - {tolD, 1024, 256, 1234}, - {tolD, 1024, 32, 1234}, - {tolD, 1024, 64, 1234}, - {tolD, 1024, 128, 1234}, - {tolD, 1024, 256, 1234}}; +const std::vector> inputsf = {{tolF, 4, 4, 1234, true}, + {tolF, 1024, 32, 1234, true}, + {tolF, 1024, 64, 1234, true}, + {tolF, 1024, 128, 1234, true}, + {tolF, 1024, 256, 1234, true}, + {tolF, 1024, 32, 1234, false}, + {tolF, 1024, 64, 1234, false}, + {tolF, 1024, 128, 1234, false}, + {tolF, 1024, 256, 1234, false}}; + +const std::vector> inputsd = {{tolD, 4, 4, 1234, true}, + {tolD, 1024, 32, 1234, true}, + {tolD, 1024, 64, 1234, true}, + {tolD, 1024, 128, 1234, true}, + {tolD, 1024, 256, 1234, true}, + {tolD, 1024, 32, 1234, false}, + {tolD, 1024, 64, 1234, false}, + {tolD, 1024, 128, 1234, false}, + {tolD, 1024, 256, 1234, false}}; using RowWeightedMeanTestF = RowWeightedMeanTest; TEST_P(RowWeightedMeanTestF, Result) @@ -227,5 +280,21 @@ TEST_P(ColWeightedMeanTestD, Result) } INSTANTIATE_TEST_CASE_P(ColWeightedMeanTest, ColWeightedMeanTestD, ::testing::ValuesIn(inputsd)); +using WeightedMeanTestF = WeightedMeanTest; +TEST_P(WeightedMeanTestF, Result) +{ + ASSERT_TRUE(devArrMatch( + dexp.data().get(), dact.data().get(), params.N, raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(WeightedMeanTest, WeightedMeanTestF, ::testing::ValuesIn(inputsf)); + +using WeightedMeanTestD = WeightedMeanTest; +TEST_P(WeightedMeanTestD, Result) +{ + ASSERT_TRUE(devArrMatch( + dexp.data().get(), dact.data().get(), params.N, raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(WeightedMeanTest, WeightedMeanTestD, ::testing::ValuesIn(inputsd)); + }; // end namespace stats }; // end namespace raft From e08bf6c37f00bd9b73136ab8ebe49ab4c1a0bc4d Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 23 Feb 2022 20:13:37 +0100 Subject: [PATCH 3/6] Fix style --- .../raft/stats/detail/weighted_mean.cuh | 2 +- cpp/test/stats/weighted_mean.cu | 21 ++++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh index d71b3744ae..c8f5898fda 100644 --- a/cpp/include/raft/stats/detail/weighted_mean.cuh +++ b/cpp/include/raft/stats/detail/weighted_mean.cuh @@ -53,7 +53,7 @@ void weightedMean(Type* mu, { // sum the weights & copy back to CPU auto weight_size = along_rows ? D : N; - Type WS = 0; + Type WS = 0; raft::stats::sum(mu, weights, (IdxType)1, weight_size, false, stream); raft::update_host(&WS, mu, 1, stream); diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index 1952c4ac71..4ad5e6795d 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -30,14 +30,14 @@ struct WeightedMeanInputs { T tolerance; int M, N; unsigned long long int seed; - bool along_rows; // Used only for the weightedMean function + bool along_rows; // Used only for the weightedMean test function }; template ::std::ostream& operator<<(::std::ostream& os, const WeightedMeanInputs& I) { - return os << "{ " << I.tolerance << ", " << I.M << ", " << I.N << ", " << I.seed - << ", " << I.along_rows << "}" << std::endl; + return os << "{ " << I.tolerance << ", " << I.M << ", " << I.N << ", " << I.seed << ", " + << I.along_rows << "}" << std::endl; } ///// weighted row-wise mean test and support functions @@ -180,8 +180,8 @@ class WeightedMeanTest : public ::testing::TestWithParam> params = ::testing::TestWithParam>::GetParam(); raft::random::Rng r(params.seed); int rows = params.M, cols = params.N, len = rows * cols; - auto weight_size = params.along_rows ? cols : rows; - auto mean_size = params.along_rows ? rows : cols; + auto weight_size = params.along_rows ? cols : rows; + auto mean_size = params.along_rows ? rows : cols; cudaStream_t stream = 0; RAFT_CUDA_TRY(cudaStreamCreate(&stream)); // device-side data @@ -207,8 +207,14 @@ class WeightedMeanTest : public ::testing::TestWithParam> dexp = hexp; // compute ml-prims result - weightedMean(dact.data().get(), din.data().get(), dweights.data().get(), cols, rows, - false, params.along_rows, stream); + weightedMean(dact.data().get(), + din.data().get(), + dweights.data().get(), + cols, + rows, + false, + params.along_rows, + stream); // adjust tolerance to account for round-off accumulation params.tolerance *= params.N; @@ -223,7 +229,6 @@ class WeightedMeanTest : public ::testing::TestWithParam> thrust::device_vector din, dweights, dexp, dact; }; - ////// Parameter sets and test instantiation static const float tolF = 128 * std::numeric_limits::epsilon(); static const double tolD = 256 * std::numeric_limits::epsilon(); From 739bf801ddb49edee667e104b204ab12831c878a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 25 Feb 2022 21:07:49 +0100 Subject: [PATCH 4/6] Address review --- cpp/test/stats/weighted_mean.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index 4ad5e6795d..58a2c819ed 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -90,7 +90,7 @@ class RowWeightedMeanTest : public ::testing::TestWithParam> naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, false); dexp = hexp; - // compute ml-prims result + // compute result weightedMean(dact.data().get(), din.data().get(), dweights.data().get(), cols, rows, - false, + false, // row_major=true is already tested through col and row weighted mean params.along_rows, stream); From 987b6f48854013f1025e46620fa0db971b0b11ce Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 1 Mar 2022 23:20:45 +0100 Subject: [PATCH 5/6] Add test for row_major in weighted mean --- cpp/test/stats/weighted_mean.cu | 45 +++++++++++++++++---------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu index f85774fdcd..d78175fc21 100644 --- a/cpp/test/stats/weighted_mean.cu +++ b/cpp/test/stats/weighted_mean.cu @@ -31,6 +31,7 @@ struct WeightedMeanInputs { int M, N; unsigned long long int seed; bool along_rows; // Used only for the weightedMean test function + bool row_major; }; template @@ -201,9 +202,9 @@ class WeightedMeanTest : public ::testing::TestWithParam> // compute naive result & copy to GPU if (params.along_rows) - naiveRowWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, false); + naiveRowWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major); else - naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, false); + naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major); dexp = hexp; // compute result @@ -212,7 +213,7 @@ class WeightedMeanTest : public ::testing::TestWithParam> dweights.data().get(), cols, rows, - false, // row_major=true is already tested through col and row weighted mean + params.row_major, params.along_rows, stream); @@ -233,25 +234,25 @@ class WeightedMeanTest : public ::testing::TestWithParam> static const float tolF = 128 * std::numeric_limits::epsilon(); static const double tolD = 256 * std::numeric_limits::epsilon(); -const std::vector> inputsf = {{tolF, 4, 4, 1234, true}, - {tolF, 1024, 32, 1234, true}, - {tolF, 1024, 64, 1234, true}, - {tolF, 1024, 128, 1234, true}, - {tolF, 1024, 256, 1234, true}, - {tolF, 1024, 32, 1234, false}, - {tolF, 1024, 64, 1234, false}, - {tolF, 1024, 128, 1234, false}, - {tolF, 1024, 256, 1234, false}}; - -const std::vector> inputsd = {{tolD, 4, 4, 1234, true}, - {tolD, 1024, 32, 1234, true}, - {tolD, 1024, 64, 1234, true}, - {tolD, 1024, 128, 1234, true}, - {tolD, 1024, 256, 1234, true}, - {tolD, 1024, 32, 1234, false}, - {tolD, 1024, 64, 1234, false}, - {tolD, 1024, 128, 1234, false}, - {tolD, 1024, 256, 1234, false}}; +const std::vector> inputsf = {{tolF, 4, 4, 1234, true, true}, + {tolF, 1024, 32, 1234, true, false}, + {tolF, 1024, 64, 1234, true, true}, + {tolF, 1024, 128, 1234, true, false}, + {tolF, 1024, 256, 1234, true, true}, + {tolF, 1024, 32, 1234, false, false}, + {tolF, 1024, 64, 1234, false, true}, + {tolF, 1024, 128, 1234, false, false}, + {tolF, 1024, 256, 1234, false, true}}; + +const std::vector> inputsd = {{tolD, 4, 4, 1234, true, true}, + {tolD, 1024, 32, 1234, true, false}, + {tolD, 1024, 64, 1234, true, true}, + {tolD, 1024, 128, 1234, true, false}, + {tolD, 1024, 256, 1234, true, true}, + {tolD, 1024, 32, 1234, false, false}, + {tolD, 1024, 64, 1234, false, true}, + {tolD, 1024, 128, 1234, false, false}, + {tolD, 1024, 256, 1234, false, true}}; using RowWeightedMeanTestF = RowWeightedMeanTest; TEST_P(RowWeightedMeanTestF, Result) From cfbd3d0809f9673573525d1c0196f863e9f0cf06 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 2 Mar 2022 01:16:42 +0100 Subject: [PATCH 6/6] mirror change to cuh file --- cpp/include/raft/stats/weighted_mean.cuh | 52 +++++++++++++++++++----- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/stats/weighted_mean.cuh b/cpp/include/raft/stats/weighted_mean.cuh index fe54d927ca..0e8338fe84 100644 --- a/cpp/include/raft/stats/weighted_mean.cuh +++ b/cpp/include/raft/stats/weighted_mean.cuh @@ -25,9 +25,39 @@ namespace raft { namespace stats { /** - * @brief Compute the row-wise weighted mean of the input matrix + * @brief Compute the weighted mean of the input matrix with a + * vector of weights, along rows or along columns * * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing + * @param mu the output mean vector + * @param data the input matrix + * @param weights weight of size D if along_row is true, else of size N + * @param D number of columns of data + * @param N number of rows of data + * @param row_major data input matrix is row-major or not + * @param along_rows whether to reduce along rows or columns + * @param stream cuda stream to launch work on + */ +template +void weightedMean(Type* mu, + const Type* data, + const Type* weights, + IdxType D, + IdxType N, + bool row_major, + bool along_rows, + cudaStream_t stream) +{ + detail::weightedMean(mu, data, weights, D, N, row_major, along_rows, stream); +} + +/** + * @brief Compute the row-wise weighted mean of the input matrix with a + * vector of column weights + * + * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing * @param mu the output mean vector * @param data the input matrix (assumed to be row-major) * @param weights per-column means @@ -35,29 +65,31 @@ namespace stats { * @param N number of rows of data * @param stream cuda stream to launch work on */ -template +template void rowWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream) + Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream) { - detail::rowWeightedMean(mu, data, weights, D, N, stream); + weightedMean(mu, data, weights, D, N, true, true, stream); } /** - * @brief Compute the column-wise weighted mean of the input matrix + * @brief Compute the column-wise weighted mean of the input matrix with a + * vector of row weights * * @tparam Type the data type + * @tparam IdxType Integer type used to for addressing * @param mu the output mean vector - * @param data the input matrix (assumed to be column-major) - * @param weights per-column means + * @param data the input matrix (assumed to be row-major) + * @param weights per-row means * @param D number of columns of data * @param N number of rows of data * @param stream cuda stream to launch work on */ -template +template void colWeightedMean( - Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream) + Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream) { - detail::colWeightedMean(mu, data, weights, D, N, stream); + weightedMean(mu, data, weights, D, N, true, false, stream); } }; // end namespace stats }; // end namespace raft