diff --git a/cpp/include/raft/stats/detail/weighted_mean.cuh b/cpp/include/raft/stats/detail/weighted_mean.cuh
index 0069cf0a3f..6d6f901fab 100644
--- a/cpp/include/raft/stats/detail/weighted_mean.cuh
+++ b/cpp/include/raft/stats/detail/weighted_mean.cuh
@@ -17,75 +17,55 @@
 #pragma once
 
 #include <raft/cudart_utils.h>
-#include <raft/linalg/coalesced_reduction.cuh>
-#include <raft/linalg/strided_reduction.cuh>
+#include <raft/linalg/reduce.cuh>
+#include <raft/stats/sum.cuh>
 
 namespace raft {
 namespace stats {
 namespace detail {
 
 /**
- * @brief Compute the row-wise weighted mean of the input matrix
+ * @brief Compute the weighted mean of the input matrix with a
+ * vector of weights, along rows or along columns
  *
  * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
 * @param mu the output mean vector
- * @param data the input matrix (assumed to be row-major)
- * @param weights per-column means
+ * @param data the input matrix
+ * @param weights weight of size D if along_rows is true, else of size N
 * @param D number of columns of data
 * @param N number of rows of data
+ * @param row_major whether the input matrix is row-major or not
+ * @param along_rows whether to reduce along rows or columns
 * @param stream cuda stream to launch work on
 */
-template <typename Type>
-void rowWeightedMean(
-  Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
+template <typename Type, typename IdxType = int>
+void weightedMean(Type* mu,
+                  const Type* data,
+                  const Type* weights,
+                  IdxType D,
+                  IdxType N,
+                  bool row_major,
+                  bool along_rows,
+                  cudaStream_t stream)
 {
   // sum the weights & copy back to CPU
-  Type WS = 0;
-  raft::linalg::coalescedReduction(mu, weights, D, 1, (Type)0, stream, false);
+  auto weight_size = along_rows ? D : N;
+  Type WS          = 0;
+  raft::stats::sum(mu, weights, (IdxType)1, weight_size, false, stream);
   raft::update_host(&WS, mu, 1, stream);
 
-  raft::linalg::coalescedReduction(
+  raft::linalg::reduce(
     mu,
     data,
     D,
     N,
     (Type)0,
+    row_major,
+    along_rows,
     stream,
     false,
-    [weights] __device__(Type v, int i) { return v * weights[i]; },
-    [] __device__(Type a, Type b) { return a + b; },
-    [WS] __device__(Type v) { return v / WS; });
-}
-
-/**
- * @brief Compute the column-wise weighted mean of the input matrix
- *
- * @tparam Type the data type
- * @param mu the output mean vector
- * @param data the input matrix (assumed to be column-major)
- * @param weights per-column means
- * @param D number of columns of data
- * @param N number of rows of data
- * @param stream cuda stream to launch work on
- */
-template <typename Type>
-void colWeightedMean(
-  Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
-{
-  // sum the weights & copy back to CPU
-  Type WS = 0;
-  raft::linalg::stridedReduction(mu, weights, 1, N, (Type)0, stream, false);
-  raft::update_host(&WS, mu, 1, stream);
-
-  raft::linalg::stridedReduction(
-    mu,
-    data,
-    D,
-    N,
-    (Type)0,
-    stream,
-    false,
-    [weights] __device__(Type v, int i) { return v * weights[i]; },
+    [weights] __device__(Type v, IdxType i) { return v * weights[i]; },
     [] __device__(Type a, Type b) { return a + b; },
     [WS] __device__(Type v) { return v / WS; });
 }
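For reference, the fused raft::linalg::reduce call in the new detail::weightedMean decomposes into three steps: scale each element by its weight (the main op lambda), sum the scaled elements along the chosen axis (the reduce op lambda), and divide the accumulated sum by the total weight WS (the final op lambda). The host-side sketch below illustrates only the row-major, along-rows case; it is not part of the patch, and the function name is a placeholder.

#include <vector>

// Illustrative reference only: one weighted mean per row of a row-major
// N x D matrix, using D per-column weights. Mirrors the main/reduce/final
// ops of the device implementation above.
template <typename Type>
void host_row_weighted_mean_sketch(std::vector<Type>& mu,
                                   const std::vector<Type>& data,
                                   const std::vector<Type>& weights,
                                   int D,
                                   int N)
{
  Type WS = Type(0);  // total weight, computed on device via raft::stats::sum
  for (int j = 0; j < D; ++j) {
    WS += weights[j];
  }

  mu.assign(N, Type(0));
  for (int i = 0; i < N; ++i) {
    Type acc = Type(0);
    for (int j = 0; j < D; ++j) {
      acc += data[i * D + j] * weights[j];  // main op, then the sum reduce op
    }
    mu[i] = acc / WS;  // final op
  }
}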
diff --git a/cpp/include/raft/stats/weighted_mean.cuh b/cpp/include/raft/stats/weighted_mean.cuh
index fe54d927ca..0e8338fe84 100644
--- a/cpp/include/raft/stats/weighted_mean.cuh
+++ b/cpp/include/raft/stats/weighted_mean.cuh
@@ -25,9 +25,39 @@ namespace raft {
 namespace stats {
 
 /**
- * @brief Compute the row-wise weighted mean of the input matrix
+ * @brief Compute the weighted mean of the input matrix with a
+ * vector of weights, along rows or along columns
 *
 * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
+ * @param mu the output mean vector
+ * @param data the input matrix
+ * @param weights weight of size D if along_rows is true, else of size N
+ * @param D number of columns of data
+ * @param N number of rows of data
+ * @param row_major whether the input matrix is row-major or not
+ * @param along_rows whether to reduce along rows or columns
+ * @param stream cuda stream to launch work on
+ */
+template <typename Type, typename IdxType = int>
+void weightedMean(Type* mu,
+                  const Type* data,
+                  const Type* weights,
+                  IdxType D,
+                  IdxType N,
+                  bool row_major,
+                  bool along_rows,
+                  cudaStream_t stream)
+{
+  detail::weightedMean(mu, data, weights, D, N, row_major, along_rows, stream);
+}
+
+/**
+ * @brief Compute the row-wise weighted mean of the input matrix with a
+ * vector of column weights
+ *
+ * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
 * @param mu the output mean vector
 * @param data the input matrix (assumed to be row-major)
 * @param weights per-column means
@@ -35,29 +65,31 @@ namespace stats {
 * @param N number of rows of data
 * @param stream cuda stream to launch work on
 */
-template <typename Type>
+template <typename Type, typename IdxType = int>
 void rowWeightedMean(
-  Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
+  Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
 {
-  detail::rowWeightedMean(mu, data, weights, D, N, stream);
+  weightedMean(mu, data, weights, D, N, true, true, stream);
 }
 
 /**
- * @brief Compute the column-wise weighted mean of the input matrix
+ * @brief Compute the column-wise weighted mean of the input matrix with a
+ * vector of row weights
 *
 * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
 * @param mu the output mean vector
- * @param data the input matrix (assumed to be column-major)
- * @param weights per-column means
+ * @param data the input matrix (assumed to be row-major)
+ * @param weights per-row means
 * @param D number of columns of data
 * @param N number of rows of data
 * @param stream cuda stream to launch work on
 */
-template <typename Type>
+template <typename Type, typename IdxType = int>
 void colWeightedMean(
-  Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
+  Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
 {
-  detail::colWeightedMean(mu, data, weights, D, N, stream);
+  weightedMean(mu, data, weights, D, N, true, false, stream);
 }
 };  // end namespace stats
 };  // end namespace raft
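A minimal usage sketch of the public API added above follows. It is not part of the patch; the wrapper function, the sizes, and the thrust containers are assumed for illustration (the tests below use the same containers).

#include <raft/stats/weighted_mean.cuh>

#include <thrust/device_vector.h>

void weighted_mean_example(cudaStream_t stream)
{
  const int D = 64, N = 1024;                // assumed sizes
  thrust::device_vector<float> data(N * D);  // row-major N x D matrix
  thrust::device_vector<float> weights(D);   // one weight per column
  thrust::device_vector<float> mu(N);        // one weighted mean per row

  // ... fill data and weights on the stream ...

  // generic entry point: row-major layout, reduce along rows
  raft::stats::weightedMean(
    mu.data().get(), data.data().get(), weights.data().get(), D, N, true, true, stream);

  // the existing convenience wrapper now forwards to the same call
  raft::stats::rowWeightedMean(
    mu.data().get(), data.data().get(), weights.data().get(), D, N, stream);
}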
diff --git a/cpp/include/raft/stats/weighted_mean.hpp b/cpp/include/raft/stats/weighted_mean.hpp
index 6d2fd1e928..4f53067e65 100644
--- a/cpp/include/raft/stats/weighted_mean.hpp
+++ b/cpp/include/raft/stats/weighted_mean.hpp
@@ -29,9 +29,39 @@ namespace raft {
 namespace stats {
 
 /**
- * @brief Compute the row-wise weighted mean of the input matrix
+ * @brief Compute the weighted mean of the input matrix with a
+ * vector of weights, along rows or along columns
 *
 * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
+ * @param mu the output mean vector
+ * @param data the input matrix
+ * @param weights weight of size D if along_rows is true, else of size N
+ * @param D number of columns of data
+ * @param N number of rows of data
+ * @param row_major whether the input matrix is row-major or not
+ * @param along_rows whether to reduce along rows or columns
+ * @param stream cuda stream to launch work on
+ */
+template <typename Type, typename IdxType = int>
+void weightedMean(Type* mu,
+                  const Type* data,
+                  const Type* weights,
+                  IdxType D,
+                  IdxType N,
+                  bool row_major,
+                  bool along_rows,
+                  cudaStream_t stream)
+{
+  detail::weightedMean(mu, data, weights, D, N, row_major, along_rows, stream);
+}
+
+/**
+ * @brief Compute the row-wise weighted mean of the input matrix with a
+ * vector of column weights
+ *
+ * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
 * @param mu the output mean vector
 * @param data the input matrix (assumed to be row-major)
 * @param weights per-column means
@@ -39,29 +69,31 @@ namespace stats {
 * @param N number of rows of data
 * @param stream cuda stream to launch work on
 */
-template <typename Type>
+template <typename Type, typename IdxType = int>
 void rowWeightedMean(
-  Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
+  Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
 {
-  detail::rowWeightedMean(mu, data, weights, D, N, stream);
+  weightedMean(mu, data, weights, D, N, true, true, stream);
 }
 
 /**
- * @brief Compute the column-wise weighted mean of the input matrix
+ * @brief Compute the column-wise weighted mean of the input matrix with a
+ * vector of row weights
 *
 * @tparam Type the data type
+ * @tparam IdxType Integer type used for addressing
 * @param mu the output mean vector
- * @param data the input matrix (assumed to be column-major)
- * @param weights per-column means
+ * @param data the input matrix (assumed to be row-major)
+ * @param weights per-row means
 * @param D number of columns of data
 * @param N number of rows of data
 * @param stream cuda stream to launch work on
 */
-template <typename Type>
+template <typename Type, typename IdxType = int>
 void colWeightedMean(
-  Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
+  Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
 {
-  detail::colWeightedMean(mu, data, weights, D, N, stream);
+  weightedMean(mu, data, weights, D, N, true, false, stream);
 }
 };  // end namespace stats
 };  // end namespace raft
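To summarize the parameter documentation above for the tests that follow: for an N x D input, along_rows == true consumes D weights (one per column) and produces N means (one per row), while along_rows == false consumes N weights (one per row) and produces D means (one per column); row_major only describes the memory layout of data and does not change which axis is reduced. The column-wise case looks like the sketch below, which is illustrative only and not part of the patch (buffer names and the wrapper function are assumed).

#include <raft/stats/weighted_mean.cuh>

// One weighted mean per column of a row-major N x D matrix:
// weights has length N, mu has length D.
void weighted_colmean_example(
  float* mu, const float* data, const float* weights, int D, int N, cudaStream_t stream)
{
  raft::stats::weightedMean(mu, data, weights, D, N, true, false, stream);

  // equivalent convenience wrapper after this change
  raft::stats::colWeightedMean(mu, data, weights, D, N, stream);
}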
diff --git a/cpp/test/stats/weighted_mean.cu b/cpp/test/stats/weighted_mean.cu
index dc67947a27..d78175fc21 100644
--- a/cpp/test/stats/weighted_mean.cu
+++ b/cpp/test/stats/weighted_mean.cu
@@ -30,13 +30,15 @@ struct WeightedMeanInputs {
   T tolerance;
   int M, N;
   unsigned long long int seed;
+  bool along_rows;  // Used only for the weightedMean test function
+  bool row_major;
 };
 
 template <typename T>
 ::std::ostream& operator<<(::std::ostream& os, const WeightedMeanInputs<T>& I)
 {
-  return os << "{ " << I.tolerance << ", " << I.M << ", " << I.N << ", " << I.seed << "}"
-            << std::endl;
+  return os << "{ " << I.tolerance << ", " << I.M << ", " << I.N << ", " << I.seed << ", "
+            << I.along_rows << "}" << std::endl;
 }
 
 ///// weighted row-wise mean test and support functions
@@ -89,7 +91,7 @@ class RowWeightedMeanTest : public ::testing::TestWithParam<WeightedMeanInputs<
   thrust::device_vector<T> din, dweights, dexp, dact;
 };
 
+template <typename T>
+class WeightedMeanTest : public ::testing::TestWithParam<WeightedMeanInputs<T>> {
+ protected:
+  void SetUp() override
+  {
+    params = ::testing::TestWithParam<WeightedMeanInputs<T>>::GetParam();
+    raft::random::Rng r(params.seed);
+    int rows = params.M, cols = params.N, len = rows * cols;
+    auto weight_size    = params.along_rows ? cols : rows;
+    auto mean_size      = params.along_rows ? rows : cols;
+    cudaStream_t stream = 0;
+    RAFT_CUDA_TRY(cudaStreamCreate(&stream));
+    // device-side data
+    din.resize(len);
+    dweights.resize(weight_size);
+    dexp.resize(mean_size);
+    dact.resize(mean_size);
+
+    // create random matrix and weights
+    r.uniform(din.data().get(), len, T(-1.0), T(1.0), stream);
+    r.uniform(dweights.data().get(), weight_size, T(-1.0), T(1.0), stream);
+
+    // host-side data
+    thrust::host_vector<T> hin      = din;
+    thrust::host_vector<T> hweights = dweights;
+    thrust::host_vector<T> hexp(mean_size);
+
+    // compute naive result & copy to GPU
+    if (params.along_rows)
+      naiveRowWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major);
+    else
+      naiveColWeightedMean(hexp.data(), hin.data(), hweights.data(), rows, cols, params.row_major);
+    dexp = hexp;
+
+    // compute result
+    weightedMean(dact.data().get(),
+                 din.data().get(),
+                 dweights.data().get(),
+                 cols,
+                 rows,
+                 params.row_major,
+                 params.along_rows,
+                 stream);
+
+    // adjust tolerance to account for round-off accumulation
+    params.tolerance *= params.N;
+    RAFT_CUDA_TRY(cudaStreamDestroy(stream));
+  }
+
+  void TearDown() override {}
+
+ protected:
+  WeightedMeanInputs<T> params;
+  thrust::host_vector<T> hin, hweights;
+  thrust::device_vector<T> din, dweights, dexp, dact;
+};
+
 ////// Parameter sets and test instantiation
 static const float tolF  = 128 * std::numeric_limits<float>::epsilon();
 static const double tolD = 256 * std::numeric_limits<double>::epsilon();
 
-const std::vector<WeightedMeanInputs<float>> inputsf = {{tolF, 4, 4, 1234},
-                                                        {tolF, 1024, 32, 1234},
-                                                        {tolF, 1024, 64, 1234},
-                                                        {tolF, 1024, 128, 1234},
-                                                        {tolF, 1024, 256, 1234},
-                                                        {tolF, 1024, 32, 1234},
-                                                        {tolF, 1024, 64, 1234},
-                                                        {tolF, 1024, 128, 1234},
-                                                        {tolF, 1024, 256, 1234}};
-
-const std::vector<WeightedMeanInputs<double>> inputsd = {{tolD, 4, 4, 1234},
-                                                         {tolD, 1024, 32, 1234},
-                                                         {tolD, 1024, 64, 1234},
-                                                         {tolD, 1024, 128, 1234},
-                                                         {tolD, 1024, 256, 1234},
-                                                         {tolD, 1024, 32, 1234},
-                                                         {tolD, 1024, 64, 1234},
-                                                         {tolD, 1024, 128, 1234},
-                                                         {tolD, 1024, 256, 1234}};
+const std::vector<WeightedMeanInputs<float>> inputsf = {{tolF, 4, 4, 1234, true, true},
+                                                        {tolF, 1024, 32, 1234, true, false},
+                                                        {tolF, 1024, 64, 1234, true, true},
+                                                        {tolF, 1024, 128, 1234, true, false},
+                                                        {tolF, 1024, 256, 1234, true, true},
+                                                        {tolF, 1024, 32, 1234, false, false},
+                                                        {tolF, 1024, 64, 1234, false, true},
+                                                        {tolF, 1024, 128, 1234, false, false},
+                                                        {tolF, 1024, 256, 1234, false, true}};
+
+const std::vector<WeightedMeanInputs<double>> inputsd = {{tolD, 4, 4, 1234, true, true},
+                                                         {tolD, 1024, 32, 1234, true, false},
+                                                         {tolD, 1024, 64, 1234, true, true},
+                                                         {tolD, 1024, 128, 1234, true, false},
+                                                         {tolD, 1024, 256, 1234, true, true},
+                                                         {tolD, 1024, 32, 1234, false, false},
+                                                         {tolD, 1024, 64, 1234, false, true},
+                                                         {tolD, 1024, 128, 1234, false, false},
+                                                         {tolD, 1024, 256, 1234, false, true}};
 
 using RowWeightedMeanTestF = RowWeightedMeanTest<float>;
 TEST_P(RowWeightedMeanTestF, Result)
@@ -227,5 +286,21 @@ TEST_P(ColWeightedMeanTestD, Result)
 }
 INSTANTIATE_TEST_CASE_P(ColWeightedMeanTest, ColWeightedMeanTestD, ::testing::ValuesIn(inputsd));
 
+using WeightedMeanTestF = WeightedMeanTest<float>;
+TEST_P(WeightedMeanTestF, Result)
+{
+  ASSERT_TRUE(devArrMatch(
+    dexp.data().get(), dact.data().get(), params.N, raft::CompareApprox<float>(params.tolerance)));
+}
+INSTANTIATE_TEST_CASE_P(WeightedMeanTest, WeightedMeanTestF, ::testing::ValuesIn(inputsf));
+
+using WeightedMeanTestD = WeightedMeanTest<double>;
+TEST_P(WeightedMeanTestD, Result)
+{
+  ASSERT_TRUE(devArrMatch(
+    dexp.data().get(), dact.data().get(), params.N, raft::CompareApprox<double>(params.tolerance)));
+}
+INSTANTIATE_TEST_CASE_P(WeightedMeanTest, WeightedMeanTestD, ::testing::ValuesIn(inputsd));
+
 };  // end namespace stats
 };  // end namespace raft