Skip to content

Commit

Permalink
Unify weighted mean code (#514)
Browse files Browse the repository at this point in the history
Needed for rapidsai/cuml#4428

Authors:
  - Micka (https://github.com/lowener)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #514
  • Loading branch information
lowener authored Mar 2, 2022
1 parent 091e2ac commit a6f3caf
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 88 deletions.
70 changes: 25 additions & 45 deletions cpp/include/raft/stats/detail/weighted_mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,75 +17,55 @@
#pragma once

#include <raft/cudart_utils.h>
#include <raft/linalg/coalesced_reduction.cuh>
#include <raft/linalg/strided_reduction.cuh>
#include <raft/linalg/reduce.hpp>
#include <raft/stats/sum.hpp>

namespace raft {
namespace stats {
namespace detail {

/**
* @brief Compute the row-wise weighted mean of the input matrix
* @brief Compute the row-wise weighted mean of the input matrix with a
* vector of weights
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix (assumed to be row-major)
* @param weights per-column means
* @param data the input matrix
* @param weights weight of size D if along_row is true, else of size N
* @param D number of columns of data
* @param N number of rows of data
* @param row_major data input matrix is row-major or not
* @param along_rows whether to reduce along rows or columns
* @param stream cuda stream to launch work on
*/
template <typename Type>
void rowWeightedMean(
Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
template <typename Type, typename IdxType = int>
void weightedMean(Type* mu,
const Type* data,
const Type* weights,
IdxType D,
IdxType N,
bool row_major,
bool along_rows,
cudaStream_t stream)
{
// sum the weights & copy back to CPU
Type WS = 0;
raft::linalg::coalescedReduction(mu, weights, D, 1, (Type)0, stream, false);
auto weight_size = along_rows ? D : N;
Type WS = 0;
raft::stats::sum(mu, weights, (IdxType)1, weight_size, false, stream);
raft::update_host(&WS, mu, 1, stream);

raft::linalg::coalescedReduction(
raft::linalg::reduce(
mu,
data,
D,
N,
(Type)0,
row_major,
along_rows,
stream,
false,
[weights] __device__(Type v, int i) { return v * weights[i]; },
[] __device__(Type a, Type b) { return a + b; },
[WS] __device__(Type v) { return v / WS; });
}

/**
* @brief Compute the column-wise weighted mean of the input matrix
*
* @tparam Type the data type
* @param mu the output mean vector
* @param data the input matrix (assumed to be column-major)
* @param weights per-column means
* @param D number of columns of data
* @param N number of rows of data
* @param stream cuda stream to launch work on
*/
template <typename Type>
void colWeightedMean(
Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
{
// sum the weights & copy back to CPU
Type WS = 0;
raft::linalg::stridedReduction(mu, weights, 1, N, (Type)0, stream, false);
raft::update_host(&WS, mu, 1, stream);

raft::linalg::stridedReduction(
mu,
data,
D,
N,
(Type)0,
stream,
false,
[weights] __device__(Type v, int i) { return v * weights[i]; },
[weights] __device__(Type v, IdxType i) { return v * weights[i]; },
[] __device__(Type a, Type b) { return a + b; },
[WS] __device__(Type v) { return v / WS; });
}
Expand Down
52 changes: 42 additions & 10 deletions cpp/include/raft/stats/weighted_mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,71 @@ namespace raft {
namespace stats {

/**
* @brief Compute the row-wise weighted mean of the input matrix
* @brief Compute the weighted mean of the input matrix with a
* vector of weights, along rows or along columns
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix
* @param weights weight of size D if along_row is true, else of size N
* @param D number of columns of data
* @param N number of rows of data
* @param row_major data input matrix is row-major or not
* @param along_rows whether to reduce along rows or columns
* @param stream cuda stream to launch work on
*/
template <typename Type, typename IdxType = int>
void weightedMean(Type* mu,
const Type* data,
const Type* weights,
IdxType D,
IdxType N,
bool row_major,
bool along_rows,
cudaStream_t stream)
{
detail::weightedMean(mu, data, weights, D, N, row_major, along_rows, stream);
}

/**
* @brief Compute the row-wise weighted mean of the input matrix with a
* vector of column weights
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix (assumed to be row-major)
* @param weights per-column means
* @param D number of columns of data
* @param N number of rows of data
* @param stream cuda stream to launch work on
*/
template <typename Type>
template <typename Type, typename IdxType = int>
void rowWeightedMean(
Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
{
detail::rowWeightedMean(mu, data, weights, D, N, stream);
weightedMean(mu, data, weights, D, N, true, true, stream);
}

/**
* @brief Compute the column-wise weighted mean of the input matrix
* @brief Compute the column-wise weighted mean of the input matrix with a
* vector of row weights
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix (assumed to be column-major)
* @param weights per-column means
* @param data the input matrix (assumed to be row-major)
* @param weights per-row means
* @param D number of columns of data
* @param N number of rows of data
* @param stream cuda stream to launch work on
*/
template <typename Type>
template <typename Type, typename IdxType = int>
void colWeightedMean(
Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
{
detail::colWeightedMean(mu, data, weights, D, N, stream);
weightedMean(mu, data, weights, D, N, true, false, stream);
}
}; // end namespace stats
}; // end namespace raft
Expand Down
52 changes: 42 additions & 10 deletions cpp/include/raft/stats/weighted_mean.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,39 +29,71 @@ namespace raft {
namespace stats {

/**
* @brief Compute the row-wise weighted mean of the input matrix
* @brief Compute the weighted mean of the input matrix with a
* vector of weights, along rows or along columns
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix
* @param weights weight of size D if along_row is true, else of size N
* @param D number of columns of data
* @param N number of rows of data
* @param row_major data input matrix is row-major or not
* @param along_rows whether to reduce along rows or columns
* @param stream cuda stream to launch work on
*/
template <typename Type, typename IdxType = int>
void weightedMean(Type* mu,
const Type* data,
const Type* weights,
IdxType D,
IdxType N,
bool row_major,
bool along_rows,
cudaStream_t stream)
{
detail::weightedMean(mu, data, weights, D, N, row_major, along_rows, stream);
}

/**
* @brief Compute the row-wise weighted mean of the input matrix with a
* vector of column weights
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix (assumed to be row-major)
* @param weights per-column means
* @param D number of columns of data
* @param N number of rows of data
* @param stream cuda stream to launch work on
*/
template <typename Type>
template <typename Type, typename IdxType = int>
void rowWeightedMean(
Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
{
detail::rowWeightedMean(mu, data, weights, D, N, stream);
weightedMean(mu, data, weights, D, N, true, true, stream);
}

/**
* @brief Compute the column-wise weighted mean of the input matrix
* @brief Compute the column-wise weighted mean of the input matrix with a
* vector of row weights
*
* @tparam Type the data type
* @tparam IdxType Integer type used to for addressing
* @param mu the output mean vector
* @param data the input matrix (assumed to be column-major)
* @param weights per-column means
* @param data the input matrix (assumed to be row-major)
* @param weights per-row means
* @param D number of columns of data
* @param N number of rows of data
* @param stream cuda stream to launch work on
*/
template <typename Type>
template <typename Type, typename IdxType = int>
void colWeightedMean(
Type* mu, const Type* data, const Type* weights, int D, int N, cudaStream_t stream)
Type* mu, const Type* data, const Type* weights, IdxType D, IdxType N, cudaStream_t stream)
{
detail::colWeightedMean(mu, data, weights, D, N, stream);
weightedMean(mu, data, weights, D, N, true, false, stream);
}
}; // end namespace stats
}; // end namespace raft
Expand Down
Loading

0 comments on commit a6f3caf

Please sign in to comment.