Add Covariance, Pearson correlation for sort groupby (libcudf) #9154

Merged
merged 81 commits into from
Oct 18, 2021
Changes from 80 commits
763c53a
add CORR aggregation to groupby, headers, classes, visitor(sort)
karthikeyann Aug 31, 2021
4c989a9
add group_corr.cu
karthikeyann Aug 31, 2021
015795c
add unit test temporarily
karthikeyann Aug 31, 2021
ba6e50a
create new PR for pearson groupby correlation
skirui-source Sep 2, 2021
b198a51
adding corr. func in python
skirui-source Sep 2, 2021
b7464a2
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 3, 2021
3d00307
Revert "create new PR for pearson groupby correlation"
karthikeyann Sep 6, 2021
1200437
Revert "adding corr. func in python"
karthikeyann Sep 6, 2021
178f28a
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 6, 2021
60293cc
rename CORR to CORRELATION, added correlation_type as arg
karthikeyann Sep 6, 2021
d421d6d
add shallow_hash(column_view)
karthikeyann Sep 7, 2021
9c4a9f3
add CompoundTypes to type_lists
karthikeyann Sep 7, 2021
a3dd235
add shallow_hash tests
karthikeyann Sep 7, 2021
2365d07
add column copy test
karthikeyann Sep 7, 2021
88726a4
add shallow_equal(column_view) and tests
karthikeyann Sep 7, 2021
d52509d
update result_cache to use shallow_hash, shallow_equal
karthikeyann Sep 8, 2021
d9a8bd7
Update cpp/include/cudf/column/column_view.hpp
karthikeyann Sep 8, 2021
d96f870
added definition of correlation() in cython
skirui-source Sep 8, 2021
b6b92df
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 9, 2021
e3d6877
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 9, 2021
7e7f250
ignore data, nullmask, offset if parent size is empty
karthikeyann Sep 13, 2021
d1d5c3c
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 13, 2021
522e5a3
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 13, 2021
0005154
is_shallow_equal ignore children states for empty column. (not childr…
karthikeyann Sep 13, 2021
002b777
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 14, 2021
82b5a26
set STRUCT_AGGS to CORRELATION
skirui-source Sep 14, 2021
e692053
for empty column, ignore child pointers in shallow_hash
karthikeyann Sep 14, 2021
44372bc
rename is_shallow_equal to is_shallow_equivalent
karthikeyann Sep 14, 2021
d52fd53
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 15, 2021
d3e0053
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 15, 2021
e32935e
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 15, 2021
3aab04f
added ctypedef correlation_type. need to add tests
skirui-source Sep 15, 2021
ecc3a7d
use hash_combine for shallow hash
karthikeyann Sep 16, 2021
d2cd468
Apply suggestions from code review (jake)
karthikeyann Sep 16, 2021
fa40847
address review comments
karthikeyann Sep 17, 2021
f709b2a
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 17, 2021
6ac5725
update after PR #9185 updates
karthikeyann Sep 17, 2021
e863bc7
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into enh-grou…
karthikeyann Sep 17, 2021
f66fdd9
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into fea-shal…
karthikeyann Sep 18, 2021
e36b834
add boost license for hash_combine, move to diff header
karthikeyann Sep 18, 2021
a1ff894
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 18, 2021
1fbe3fc
Apply suggestions from code review (jake)
karthikeyann Sep 18, 2021
79ca5e5
Merge branches 'enh-groupby_cache_hashed' and 'fea-shallow_hash_colum…
karthikeyann Sep 18, 2021
fc3cc6b
include cleanup
karthikeyann Sep 18, 2021
eb2b0db
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 18, 2021
0593955
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 18, 2021
f7b6bb6
add missing include due to reorg
karthikeyann Sep 18, 2021
5b269ef
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 18, 2021
7db9870
update groupby corr to use hashed result cache
karthikeyann Sep 18, 2021
5bb1dc4
Revert "set STRUCT_AGGS to CORRELATION"
karthikeyann Sep 18, 2021
fb98fd5
Revert "added ctypedef correlation_type. need to add tests"
karthikeyann Sep 18, 2021
324c37d
Revert "added definition of correlation() in cython"
karthikeyann Sep 18, 2021
9f19ddf
Apply suggestions from code review (jake)
karthikeyann Sep 20, 2021
ab955bb
enable result caching of child columns in correlation
karthikeyann Sep 20, 2021
98bbc94
fix duplicate {col, agg} request extract
karthikeyann Sep 20, 2021
e750cae
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 20, 2021
9581525
address review comments
karthikeyann Sep 20, 2021
dcb0668
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 20, 2021
243490b
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into enh-grou…
karthikeyann Sep 20, 2021
8d71146
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 20, 2021
1a5f367
Update cpp/src/column/column_view.cpp
karthikeyann Sep 21, 2021
8fa765c
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 22, 2021
3e41c64
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 22, 2021
63af02d
add groupby correlation tests
karthikeyann Sep 24, 2021
14dd5bb
enable dict for sort groupby mean
karthikeyann Sep 24, 2021
b0fea02
update group_corr for null support
karthikeyann Sep 24, 2021
57db901
rename group_corr to group_correlation
karthikeyann Sep 24, 2021
0d1a91e
update doc
karthikeyann Sep 24, 2021
e10ca8c
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into enh-grou…
karthikeyann Sep 24, 2021
f6a56ce
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 24, 2021
6cd47bc
minor comment corrections
karthikeyann Sep 27, 2021
4c2611b
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 27, 2021
075ec73
add covariance, refactor correlation to use covariance
karthikeyann Sep 30, 2021
dad641e
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 30, 2021
38e9ddc
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 30, 2021
6e6459d
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Oct 4, 2021
077a187
add more null cases for correlation tests
karthikeyann Oct 4, 2021
e3f47c1
add covariance tests
karthikeyann Oct 4, 2021
6703533
Merge branch 'branch-21.12' into fea-sortgroupby_corr
karthikeyann Oct 4, 2021
dd00e0d
Merge branch 'branch-21.12' into fea-sortgroupby_corr
karthikeyann Oct 5, 2021
8426f56
Apply suggestions from code review
karthikeyann Oct 8, 2021
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -219,6 +219,7 @@ add_library(cudf
src/groupby/sort/group_argmax.cu
src/groupby/sort/group_argmin.cu
src/groupby/sort/group_collect.cu
src/groupby/sort/group_correlation.cu
src/groupby/sort/group_count.cu
src/groupby/sort/group_m2.cu
src/groupby/sort/group_max.cu
25 changes: 24 additions & 1 deletion cpp/include/cudf/aggregation.hpp
@@ -87,7 +87,9 @@ class aggregation {
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST ///< create a tdigest by merging multiple tdigests together
};
@@ -146,6 +148,7 @@ class groupby_scan_aggregation : public virtual aggregation {
};

enum class udf_type : bool { CUDA, PTX };
enum class correlation_type : int32_t { PEARSON, KENDALL, SPEARMAN };

/// Factory to create a SUM aggregation
template <typename Base = aggregation>
@@ -495,6 +498,26 @@ std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal = nu
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_m2_aggregation();

/**
* @brief Factory to create a COVARIANCE aggregation
*
* Compute covariance between two columns.
* The input columns are the child columns of a non-nullable struct column.
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_covariance_aggregation();

/**
* @brief Factory to create a CORRELATION aggregation
*
* Compute correlation coefficient between two columns.
* The input columns are the child columns of a non-nullable struct column.
*
* @param type The correlation_type (method) to compute
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_correlation_aggregation(correlation_type type);

/**
* @brief Factory to create a TDIGEST aggregation
*
73 changes: 73 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
@@ -91,6 +91,10 @@ class simple_aggregations_collector { // Declares the interface for the simple
class merge_sets_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_m2_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class covariance_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class correlation_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
@@ -129,6 +133,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class merge_lists_aggregation const& agg);
virtual void visit(class merge_sets_aggregation const& agg);
virtual void visit(class merge_m2_aggregation const& agg);
virtual void visit(class covariance_aggregation const& agg);
virtual void visit(class correlation_aggregation const& agg);
virtual void visit(class tdigest_aggregation const& agg);
virtual void visit(class merge_tdigest_aggregation const& agg);
};
@@ -890,6 +896,57 @@ class merge_m2_aggregation final : public groupby_aggregation {
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying COVARIANCE aggregation
*/
class covariance_aggregation final : public groupby_aggregation {
public:
explicit covariance_aggregation() : aggregation{COVARIANCE} {}

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<covariance_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying CORRELATION aggregation
*/
class correlation_aggregation final : public groupby_aggregation {
public:
explicit correlation_aggregation(correlation_type type) : aggregation{CORRELATION}, _type{type} {}
correlation_type _type;

bool is_equal(aggregation const& _other) const override
{
if (!this->aggregation::is_equal(_other)) { return false; }
auto const& other = dynamic_cast<correlation_aggregation const&>(_other);
return (_type == other._type);
}

size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); }

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<correlation_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }

protected:
size_t hash_impl() const { return std::hash<int>{}(static_cast<int>(_type)); }
};
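
The equality and hashing pattern used by `correlation_aggregation` can be illustrated outside cudf. The following is a minimal sketch, not the library's code: `base_agg`, `corr_agg`, `kind`, and `corr_type` are hypothetical stand-ins, and the base-class hash is assumed to depend only on the aggregation kind.

```cpp
#include <cstddef>
#include <functional>

// Hypothetical stand-ins for aggregation::Kind and correlation_type.
enum class kind { COVARIANCE, CORRELATION };
enum class corr_type : int { PEARSON, KENDALL, SPEARMAN };

struct base_agg {
  explicit base_agg(kind k) : _kind{k} {}
  virtual ~base_agg() = default;

  // Base comparison: two aggregations are equal if their kinds match.
  virtual bool is_equal(base_agg const& other) const { return _kind == other._kind; }
  // Base hash: assumed here to depend only on the kind.
  virtual std::size_t do_hash() const { return std::hash<int>{}(static_cast<int>(_kind)); }

  kind _kind;
};

struct corr_agg final : base_agg {
  explicit corr_agg(corr_type t) : base_agg{kind::CORRELATION}, _type{t} {}

  bool is_equal(base_agg const& other) const override
  {
    // Compare kinds first so the downcast below is known to succeed.
    if (!base_agg::is_equal(other)) { return false; }
    return _type == dynamic_cast<corr_agg const&>(other)._type;
  }

  // Fold the extra parameter into the hash so that aggregations that
  // compare equal also hash equal.
  std::size_t do_hash() const override
  {
    return base_agg::do_hash() ^ std::hash<int>{}(static_cast<int>(_type));
  }

  corr_type _type;
};
```

Two `corr_agg` objects built with the same `corr_type` compare equal and produce the same hash, which is the property a hashed result cache needs from its keys.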

/**
* @brief Derived aggregation class for specifying TDIGEST aggregation
*/
@@ -1174,6 +1231,18 @@ struct target_type_impl<SourceType, aggregation::MERGE_M2> {
using type = struct_view;
};

// Always use double for COVARIANCE
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::COVARIANCE> {
using type = double;
};

// Always use double for CORRELATION
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::CORRELATION> {
using type = double;
};
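
As an illustration of how such partial specializations fix the result type, here is a small self-contained sketch. The `kind` enum, the pass-through primary template, and the `target_type_t` alias are assumptions made for the example; cudf's real primary template is not shown in this diff and may behave differently.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for aggregation::Kind.
enum class kind { SUM, COVARIANCE, CORRELATION };

// Assumed primary template for this sketch: the result type equals the source type.
template <typename Source, kind K>
struct target_type_impl {
  using type = Source;
};

// Partial specializations: the result is double regardless of the source type.
template <typename Source>
struct target_type_impl<Source, kind::COVARIANCE> {
  using type = double;
};

template <typename Source>
struct target_type_impl<Source, kind::CORRELATION> {
  using type = double;
};

template <typename Source, kind K>
using target_type_t = typename target_type_impl<Source, K>::type;

// Integer and float inputs both map to double for the two new kinds.
static_assert(std::is_same_v<target_type_t<std::int32_t, kind::COVARIANCE>, double>);
static_assert(std::is_same_v<target_type_t<float, kind::CORRELATION>, double>);
// Other kinds fall back to the primary template in this sketch.
static_assert(std::is_same_v<target_type_t<std::int32_t, kind::SUM>, std::int32_t>);
```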

// Always use numeric types for TDIGEST
template <typename Source>
struct target_type_impl<Source,
@@ -1296,6 +1365,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()<aggregation::MERGE_SETS>(std::forward<Ts>(args)...);
case aggregation::MERGE_M2:
return f.template operator()<aggregation::MERGE_M2>(std::forward<Ts>(args)...);
case aggregation::COVARIANCE:
return f.template operator()<aggregation::COVARIANCE>(std::forward<Ts>(args)...);
case aggregation::CORRELATION:
return f.template operator()<aggregation::CORRELATION>(std::forward<Ts>(args)...);
case aggregation::TDIGEST:
return f.template operator()<aggregation::TDIGEST>(std::forward<Ts>(args)...);
case aggregation::MERGE_TDIGEST:
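
The dispatcher cases above turn a runtime enum value into a compile-time template argument. A minimal host-only sketch of that technique follows; `kind`, `dispatch`, and `kind_name` are hypothetical names, and the throw for unknown values is a choice made for this example rather than something taken from the diff.

```cpp
#include <stdexcept>
#include <string_view>
#include <utility>

// Hypothetical stand-in for aggregation::Kind.
enum class kind { COVARIANCE, CORRELATION };

// Calls f.operator()<K>(args...) where K is chosen from the runtime value k.
template <typename F, typename... Ts>
decltype(auto) dispatch(kind k, F&& f, Ts&&... args)
{
  switch (k) {
    case kind::COVARIANCE:
      return f.template operator()<kind::COVARIANCE>(std::forward<Ts>(args)...);
    case kind::CORRELATION:
      return f.template operator()<kind::CORRELATION>(std::forward<Ts>(args)...);
  }
  // Reached only if k holds a value outside the enumerators.
  throw std::logic_error("unsupported aggregation kind");
}

// Example functor: its call operator is templated on the kind, so each
// enumerator can get its own compile-time branch.
struct kind_name {
  template <kind K>
  std::string_view operator()() const
  {
    if constexpr (K == kind::COVARIANCE) {
      return "covariance";
    } else {
      return "correlation";
    }
  }
};
```

Calling `dispatch(kind::CORRELATION, kind_name{})` selects the `kind::CORRELATION` instantiation of the call operator. Every new enumerator needs its own `case`, which is why the diff adds one for each of the two new kinds.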
40 changes: 40 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
@@ -202,6 +202,16 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, covariance_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}
std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, correlation_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}
std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, tdigest_aggregation const& agg)
{
@@ -358,6 +368,16 @@ void aggregation_finalizer::visit(merge_m2_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(covariance_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(correlation_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(tdigest_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
@@ -691,6 +711,26 @@ std::unique_ptr<Base> make_merge_m2_aggregation()
template std::unique_ptr<aggregation> make_merge_m2_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_merge_m2_aggregation<groupby_aggregation>();

/// Factory to create a COVARIANCE aggregation
template <typename Base>
std::unique_ptr<Base> make_covariance_aggregation()
{
return std::make_unique<detail::covariance_aggregation>();
}
template std::unique_ptr<aggregation> make_covariance_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_covariance_aggregation<groupby_aggregation>();

/// Factory to create a CORRELATION aggregation
template <typename Base>
std::unique_ptr<Base> make_correlation_aggregation(correlation_type type)
{
return std::make_unique<detail::correlation_aggregation>(type);
}
template std::unique_ptr<aggregation> make_correlation_aggregation<aggregation>(
correlation_type type);
template std::unique_ptr<groupby_aggregation> make_correlation_aggregation<groupby_aggregation>(
correlation_type type);

template <typename Base>
std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids)
{
11 changes: 11 additions & 0 deletions cpp/src/groupby/hash/groupby.cu
@@ -158,6 +158,17 @@ class groupby_simple_aggregations_collector final

return aggs;
}

std::vector<std::unique_ptr<aggregation>> visit(
data_type, cudf::detail::correlation_aggregation const&) override
{
std::vector<std::unique_ptr<aggregation>> aggs;
aggs.push_back(make_sum_aggregation());
// COUNT_VALID
aggs.push_back(make_count_aggregation());

return aggs;
Comment on lines +165 to +170

Contributor:
If you want to be fancy.

Suggested change:
- std::vector<std::unique_ptr<aggregation>> aggs;
- aggs.push_back(make_sum_aggregation());
- // COUNT_VALID
- aggs.push_back(make_count_aggregation());
- return aggs;
+ return {make_sum_aggregation(), make_count_aggregation()};

Contributor Author:
std::initializer_list only allows const access to its elements. When a unique_ptr is passed in an initializer list, std::vector tries to copy it, and the copy constructor of unique_ptr is deleted.
So this fails to compile.

Contributor (@jrhemstad, Oct 7, 2021):
Ah, sure, that makes sense. This should work though?

Suggested change:
- std::vector<std::unique_ptr<aggregation>> aggs;
- aggs.push_back(make_sum_aggregation());
- // COUNT_VALID
- aggs.push_back(make_count_aggregation());
- return aggs;
+ return std::vector({make_sum_aggregation(), make_count_aggregation()});

Contributor Author:
This has the same effect as above.
Elements passed through a {} initializer_list cannot be moved from. (A const_cast inside the constructor might work, but that's not useful here.)

https://stackoverflow.com/a/8469002/1550940
Surprisingly, an array works.

     std::unique_ptr<aggregation> init[] = {make_sum_aggregation(), make_count_aggregation()};
     return std::vector<std::unique_ptr<aggregation>>{std::make_move_iterator(std::begin(init)), std::make_move_iterator(std::end(init))};
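
The array-based workaround from this thread can be tried in isolation. The sketch below uses a hypothetical `agg` struct and `make_agg` factory in place of cudf's aggregation types; only the standard library calls are real.

```cpp
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for cudf's aggregation type.
struct agg {
  std::string name;
};

// Hypothetical factory, playing the role of make_sum_aggregation() etc.
std::unique_ptr<agg> make_agg(std::string name)
{
  return std::make_unique<agg>(agg{std::move(name)});
}

std::vector<std::unique_ptr<agg>> make_sum_and_count()
{
  // A built-in array can be initialized directly from rvalue unique_ptrs;
  // unlike std::initializer_list, its elements are not const.
  std::unique_ptr<agg> init[] = {make_agg("sum"), make_agg("count")};

  // Move iterators yield rvalues on dereference, so the vector's range
  // constructor moves each unique_ptr instead of trying to copy it.
  return std::vector<std::unique_ptr<agg>>(std::make_move_iterator(std::begin(init)),
                                           std::make_move_iterator(std::end(init)));
}
```

Whether this is clearer than two `push_back` calls is a matter of taste; the merged code keeps the `push_back` form.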

}
};

template <typename Map>
133 changes: 132 additions & 1 deletion cpp/src/groupby/sort/aggregate.cpp
@@ -26,6 +26,7 @@
#include <cudf/detail/binaryop.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/groupby/sort_helper.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/unary.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/groupby.hpp>
@@ -235,11 +236,14 @@ void aggregate_result_functor::operator()<aggregation::MEAN>(aggregation const&

// TODO (dm): Special case for timestamp. Add target_type_impl for it.
// Blocked until we support operator+ on timestamps
auto col_type = cudf::is_dictionary(values.type())
? cudf::dictionary_column_view(values).keys().type()
: values.type();
auto result =
cudf::detail::binary_operation(sum_result,
count_result,
binary_operator::DIV,
cudf::detail::target_type(col_type, aggregation::MEAN),
stream,
mr);
cache.add_result(values, agg, std::move(result));
@@ -525,6 +529,133 @@ void aggregate_result_functor::operator()<aggregation::MERGE_M2>(aggregation con
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr));
};

/**
* @brief Creates column views with only valid elements in both input column views
*
* @param column_0 The first column
* @param column_1 The second column
* @return Tuple with the new null mask (if the null masks of the inputs differ) and the new column views
*/
auto column_view_with_common_nulls(column_view const& column_0, column_view const& column_1)
{
rmm::device_buffer new_nullmask = cudf::bitmask_and(table_view{{column_0, column_1}});
auto null_count = cudf::count_unset_bits(
static_cast<cudf::bitmask_type const*>(new_nullmask.data()), 0, column_0.size());
if (null_count == 0) { return std::make_tuple(std::move(new_nullmask), column_0, column_1); }
auto column_view_with_new_nullmask = [](auto const& col, void* nullmask, auto null_count) {
return column_view(col.type(),
col.size(),
col.head(),
static_cast<cudf::bitmask_type const*>(nullmask),
null_count,
col.offset(),
std::vector(col.child_begin(), col.child_end()));
};
auto new_column_0 = null_count == column_0.null_count()
? column_0
: column_view_with_new_nullmask(column_0, new_nullmask.data(), null_count);
auto new_column_1 = null_count == column_1.null_count()
? column_1
: column_view_with_new_nullmask(column_1, new_nullmask.data(), null_count);
return std::make_tuple(std::move(new_nullmask), new_column_0, new_column_1);
}
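
The idea behind `column_view_with_common_nulls` is that a row counts as valid only when it is valid in both inputs. The host-side sketch below shows the mask arithmetic on plain vectors. It assumes 32-bit mask words with one bit per row, least significant bit first, where a set bit means valid; `mask_and` and `count_unset` are hypothetical helpers, not the cudf functions.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using bitmask_word = std::uint32_t;  // assumption: 32 rows per mask word

// Bitwise AND of two validity masks of equal length:
// a row stays valid only if it is valid in both inputs.
std::vector<bitmask_word> mask_and(std::vector<bitmask_word> const& a,
                                   std::vector<bitmask_word> const& b)
{
  std::vector<bitmask_word> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] & b[i];
  }
  return out;
}

// Counts the rows in [0, num_rows) whose validity bit is 0, i.e. the nulls.
std::size_t count_unset(std::vector<bitmask_word> const& mask, std::size_t num_rows)
{
  std::size_t nulls = 0;
  for (std::size_t row = 0; row < num_rows; ++row) {
    bool const valid = ((mask[row / 32] >> (row % 32)) & 1u) != 0;
    if (!valid) { ++nulls; }
  }
  return nulls;
}
```

For example, masks `0xB` (rows 0, 1, 3 valid) and `0xD` (rows 0, 2, 3 valid) combine to `0x9`, leaving two nulls among four rows. When the combined null count equals a column's own null count, that column's mask already matches the combined mask, which is why the function above can return the original view unchanged in that case.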

/**
* @brief Compute the covariance between the two child columns of a non-nullable struct column.
*
*/
template <>
void aggregate_result_functor::operator()<aggregation::COVARIANCE>(aggregation const& agg)
{
if (cache.has_result(values, agg)) { return; }
CUDF_EXPECTS(values.type().id() == type_id::STRUCT,
"Input to `groupby covariance` must be a structs column.");
CUDF_EXPECTS(values.num_children() == 2,
"Input to `groupby covariance` must be a structs column having 2 children columns.");

// Covariance only for valid values in both columns.
// When the null masks are not identical, cached STD, MEAN, COUNT results of the original children cannot be reused.
auto [_, values_child0, values_child1] =
column_view_with_common_nulls(values.child(0), values.child(1));

auto mean_agg = make_mean_aggregation();
aggregate_result_functor(values_child0, helper, cache, stream, mr).operator()<aggregation::MEAN>(*mean_agg);
aggregate_result_functor(values_child1, helper, cache, stream, mr).operator()<aggregation::MEAN>(*mean_agg);

auto const mean0 = cache.get_result(values_child0, *mean_agg);
auto const mean1 = cache.get_result(values_child1, *mean_agg);
auto count_agg = make_count_aggregation();
auto const count = cache.get_result(values_child0, *count_agg);

cache.add_result(values,
agg,
detail::group_covariance(get_grouped_values().child(0),
get_grouped_values().child(1),
helper.group_labels(stream),
helper.num_groups(stream),
count,
mean0,
mean1,
stream,
mr));
};

/**
* @brief Compute the correlation between the two child columns of a non-nullable struct column.
*
*/
template <>
void aggregate_result_functor::operator()<aggregation::CORRELATION>(aggregation const& agg)
{
if (cache.has_result(values, agg)) { return; }
CUDF_EXPECTS(values.type().id() == type_id::STRUCT,
"Input to `groupby correlation` must be a structs column.");
CUDF_EXPECTS(
values.num_children() == 2,
"Input to `groupby correlation` must be a structs column having 2 children columns.");
CUDF_EXPECTS(values.nullable() == false,
"Input to `groupby correlation` must be a non-nullable structs column.");

auto const& corr_agg = dynamic_cast<cudf::detail::correlation_aggregation const&>(agg);
CUDF_EXPECTS(corr_agg._type == correlation_type::PEARSON,
"Only Pearson correlation is supported.");

// Correlation only for valid values in both columns.
// When the null masks are not identical, cached STD, MEAN, COUNT results of the original children cannot be reused.
auto [_, values_child0, values_child1] =
column_view_with_common_nulls(values.child(0), values.child(1));

auto std_agg = make_std_aggregation();
aggregate_result_functor(values_child0, helper, cache, stream, mr).operator()<aggregation::STD>(*std_agg);
aggregate_result_functor(values_child1, helper, cache, stream, mr).operator()<aggregation::STD>(*std_agg);

auto const stddev0 = cache.get_result(values_child0, *std_agg);
auto const stddev1 = cache.get_result(values_child1, *std_agg);

auto mean_agg = make_mean_aggregation();
auto const mean0 = cache.get_result(values_child0, *mean_agg);
auto const mean1 = cache.get_result(values_child1, *mean_agg);
auto count_agg = make_count_aggregation();
auto const count = cache.get_result(values_child0, *count_agg);

// Compute covariance here to avoid repeated computation of mean & count
auto cov_agg = make_covariance_aggregation();
cache.add_result(values,
*cov_agg,
detail::group_covariance(get_grouped_values().child(0),
get_grouped_values().child(1),
helper.group_labels(stream),
helper.num_groups(stream),
count,
mean0,
mean1,
stream,
mr));
auto const covariance = cache.get_result(values, *cov_agg);
cache.add_result(
values, agg, detail::group_correlation(covariance, stddev0, stddev1, stream, mr));
}
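
The functor above builds the Pearson coefficient from per-group count, means, standard deviations, and covariance. A CPU reference for the same arithmetic is sketched below under stated assumptions: the sample (n - 1) normalization is assumed for both covariance and standard deviation, nulls are ignored, and `grouped_pearson` and `group_result` are hypothetical names. The actual `detail::group_covariance` and `detail::group_correlation` kernels are not part of the diff shown here.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

struct group_result {
  double covariance;
  double correlation;
};

// labels[i] is the group index of row i; one result is returned per group.
std::vector<group_result> grouped_pearson(std::vector<double> const& x,
                                          std::vector<double> const& y,
                                          std::vector<std::size_t> const& labels,
                                          std::size_t num_groups)
{
  // First pass: per-group count and mean (the COUNT and MEAN aggregations).
  std::vector<double> count(num_groups, 0.0), mean_x(num_groups, 0.0), mean_y(num_groups, 0.0);
  for (std::size_t i = 0; i < x.size(); ++i) {
    count[labels[i]] += 1.0;
    mean_x[labels[i]] += x[i];
    mean_y[labels[i]] += y[i];
  }
  for (std::size_t g = 0; g < num_groups; ++g) {
    mean_x[g] /= count[g];
    mean_y[g] /= count[g];
  }

  // Second pass: per-group sums of products of deviations from the mean.
  std::vector<double> sxy(num_groups, 0.0), sxx(num_groups, 0.0), syy(num_groups, 0.0);
  for (std::size_t i = 0; i < x.size(); ++i) {
    std::size_t const g = labels[i];
    double const dx     = x[i] - mean_x[g];
    double const dy     = y[i] - mean_y[g];
    sxy[g] += dx * dy;
    sxx[g] += dx * dx;
    syy[g] += dy * dy;
  }

  // Covariance, then correlation = covariance / (stddev_x * stddev_y).
  // Groups with fewer than two rows produce NaN in this sketch.
  std::vector<group_result> out(num_groups);
  for (std::size_t g = 0; g < num_groups; ++g) {
    double const denom = count[g] - 1.0;  // assumed sample normalization
    double const cov   = sxy[g] / denom;
    double const std_x = std::sqrt(sxx[g] / denom);
    double const std_y = std::sqrt(syy[g] / denom);
    out[g]             = group_result{cov, cov / (std_x * std_y)};
  }
  return out;
}
```

With x = {1, 2, 3, 1, 2, 3}, y = {2, 4, 6, 3, 2, 1}, and labels {0, 0, 0, 1, 1, 1}, group 0 has covariance 2 and correlation 1, and group 1 has covariance -1 and correlation -1.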

/**
* @brief Generate a tdigest column from a grouped set of numeric input values.
*