Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Covariance, Pearson correlation for sort groupby (libcudf) #9154

Merged
merged 81 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from 79 commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
763c53a
add CORR aggregation to groupby, headers, classes, visitor(sort)
karthikeyann Aug 31, 2021
4c989a9
add group_corr.cu
karthikeyann Aug 31, 2021
015795c
add unit test temporarily
karthikeyann Aug 31, 2021
ba6e50a
create new PR for pearson groupby correlation
skirui-source Sep 2, 2021
b198a51
adding corr. func in python
skirui-source Sep 2, 2021
b7464a2
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 3, 2021
3d00307
Revert "create new PR for pearson groupby correlation"
karthikeyann Sep 6, 2021
1200437
Revert "adding corr. func in python"
karthikeyann Sep 6, 2021
178f28a
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 6, 2021
60293cc
rename CORR to CORRELATION, added correlation_type as arg
karthikeyann Sep 6, 2021
d421d6d
add shallow_hash(column_view)
karthikeyann Sep 7, 2021
9c4a9f3
add CompoundTypes to type_lists
karthikeyann Sep 7, 2021
a3dd235
add shallow_hash tests
karthikeyann Sep 7, 2021
2365d07
add column copy test
karthikeyann Sep 7, 2021
88726a4
add shallow_equal(column_view) and tests
karthikeyann Sep 7, 2021
d52509d
update result_cache to use shallow_hash, shallow_equal
karthikeyann Sep 8, 2021
d9a8bd7
Update cpp/include/cudf/column/column_view.hpp
karthikeyann Sep 8, 2021
d96f870
added definition of correlation() in cython
skirui-source Sep 8, 2021
b6b92df
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 9, 2021
e3d6877
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 9, 2021
7e7f250
ignore data, nullmask, offset if parent size is empty
karthikeyann Sep 13, 2021
d1d5c3c
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 13, 2021
522e5a3
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 13, 2021
0005154
is_shallow_equal ignore children states for empty column. (not childr…
karthikeyann Sep 13, 2021
002b777
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 14, 2021
82b5a26
set STRUCT_AGGS to CORRELATION
skirui-source Sep 14, 2021
e692053
for empty column, ignore child pointers in shallow_hash
karthikeyann Sep 14, 2021
44372bc
rename is_shallow_equal to is_shallow_equivalent
karthikeyann Sep 14, 2021
d52fd53
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 15, 2021
d3e0053
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 15, 2021
e32935e
Merge branch 'branch-21.10' of https://github.com/rapidsai/cudf into …
skirui-source Sep 15, 2021
3aab04f
added ctypedef correlation_type. need to add tests
skirui-source Sep 15, 2021
ecc3a7d
use hash_combine for shallow hash
karthikeyann Sep 16, 2021
d2cd468
Apply suggestions from code review (jake)
karthikeyann Sep 16, 2021
fa40847
address review comments
karthikeyann Sep 17, 2021
f709b2a
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 17, 2021
6ac5725
update after PR #9185 updates
karthikeyann Sep 17, 2021
e863bc7
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into enh-grou…
karthikeyann Sep 17, 2021
f66fdd9
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into fea-shal…
karthikeyann Sep 18, 2021
e36b834
add boost license for hash_combine, move to diff header
karthikeyann Sep 18, 2021
a1ff894
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 18, 2021
1fbe3fc
Apply suggestions from code review (jake)
karthikeyann Sep 18, 2021
79ca5e5
Merge branches 'enh-groupby_cache_hashed' and 'fea-shallow_hash_colum…
karthikeyann Sep 18, 2021
fc3cc6b
include cleanup
karthikeyann Sep 18, 2021
eb2b0db
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 18, 2021
0593955
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 18, 2021
f7b6bb6
add missing include due to reorg
karthikeyann Sep 18, 2021
5b269ef
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 18, 2021
7db9870
update groupby corr to use hashed result cache
karthikeyann Sep 18, 2021
5bb1dc4
Revert "set STRUCT_AGGS to CORRELATION"
karthikeyann Sep 18, 2021
fb98fd5
Revert "added ctypedef correlation_type. need to add tests"
karthikeyann Sep 18, 2021
324c37d
Revert "added definition of correlation() in cython"
karthikeyann Sep 18, 2021
9f19ddf
Apply suggestions from code review (jake)
karthikeyann Sep 20, 2021
ab955bb
enable result caching of child columns in correlation
karthikeyann Sep 20, 2021
98bbc94
fix duplicate {col, agg} request extract
karthikeyann Sep 20, 2021
e750cae
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 20, 2021
9581525
address review comments
karthikeyann Sep 20, 2021
dcb0668
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 20, 2021
243490b
Merge branch 'branch-21.10' of github.com:rapidsai/cudf into enh-grou…
karthikeyann Sep 20, 2021
8d71146
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 20, 2021
1a5f367
Update cpp/src/column/column_view.cpp
karthikeyann Sep 21, 2021
8fa765c
Merge branch 'fea-shallow_hash_columnview' of github.com:karthikeyann…
karthikeyann Sep 22, 2021
3e41c64
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 22, 2021
63af02d
add groupby correlation tests
karthikeyann Sep 24, 2021
14dd5bb
enable dict for sort groupby mean
karthikeyann Sep 24, 2021
b0fea02
update group_corr for null support
karthikeyann Sep 24, 2021
57db901
rename group_corr to group_correlation
karthikeyann Sep 24, 2021
0d1a91e
update doc
karthikeyann Sep 24, 2021
e10ca8c
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into enh-grou…
karthikeyann Sep 24, 2021
f6a56ce
Merge branch 'enh-groupby_cache_hashed' of github.com:karthikeyann/cu…
karthikeyann Sep 24, 2021
6cd47bc
minor comment corrections
karthikeyann Sep 27, 2021
4c2611b
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 27, 2021
075ec73
add covariance, refactor correlation to use covariance
karthikeyann Sep 30, 2021
dad641e
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 30, 2021
38e9ddc
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Sep 30, 2021
6e6459d
Merge branch 'branch-21.12' of github.com:rapidsai/cudf into fea-sort…
karthikeyann Oct 4, 2021
077a187
add more null cases for correlation tests
karthikeyann Oct 4, 2021
e3f47c1
add covariance tests
karthikeyann Oct 4, 2021
6703533
Merge branch 'branch-21.12' into fea-sortgroupby_corr
karthikeyann Oct 4, 2021
dd00e0d
Merge branch 'branch-21.12' into fea-sortgroupby_corr
karthikeyann Oct 5, 2021
8426f56
Apply suggestions from code review
karthikeyann Oct 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ add_library(cudf
src/groupby/sort/group_argmax.cu
src/groupby/sort/group_argmin.cu
src/groupby/sort/group_collect.cu
src/groupby/sort/group_correlation.cu
src/groupby/sort/group_count.cu
src/groupby/sort/group_m2.cu
src/groupby/sort/group_max.cu
Expand Down
25 changes: 24 additions & 1 deletion cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ class aggregation {
CUDA, ///< CUDA UDF based reduction
MERGE_LISTS, ///< merge multiple lists values into one list
MERGE_SETS, ///< merge multiple lists values into one list then drop duplicate entries
MERGE_M2, ///< merge partial values of M2 aggregation
MERGE_M2, ///< merge partial values of M2 aggregation,
COVARIANCE, ///< covariance between two sets of elements
CORRELATION, ///< correlation between two sets of elements
TDIGEST, ///< create a tdigest from a set of input values
MERGE_TDIGEST ///< create a tdigest by merging multiple tdigests together
};
Expand Down Expand Up @@ -146,6 +148,7 @@ class groupby_scan_aggregation : public virtual aggregation {
};

enum class udf_type : bool { CUDA, PTX };
enum class correlation_type : int32_t { PEARSON, KENDALL, SPEARMAN };

/// Factory to create a SUM aggregation
template <typename Base = aggregation>
Expand Down Expand Up @@ -495,6 +498,26 @@ std::unique_ptr<Base> make_merge_sets_aggregation(null_equality nulls_equal = nu
template <typename Base = aggregation>
std::unique_ptr<Base> make_merge_m2_aggregation();

/**
* @brief Factory to create a COVARIANCE aggregation
*
* Compute covariance between two columns.
* The input columns are child columns of a non-nullable struct columns.
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_covariance_aggregation();

/**
* @brief Factory to create a CORRELATION aggregation
*
* Compute correlation coefficient between two columns.
* The input columns are child columns of a non-nullable struct columns.
*
* @param[in] type: correlation_type
*/
template <typename Base = aggregation>
std::unique_ptr<Base> make_correlation_aggregation(correlation_type type);

/**
* @brief Factory to create a TDIGEST aggregation
*
Expand Down
73 changes: 73 additions & 0 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class simple_aggregations_collector { // Declares the interface for the simple
class merge_sets_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class merge_m2_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class covariance_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class correlation_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(data_type col_type,
class tdigest_aggregation const& agg);
virtual std::vector<std::unique_ptr<aggregation>> visit(
Expand Down Expand Up @@ -129,6 +133,8 @@ class aggregation_finalizer { // Declares the interface for the finalizer
virtual void visit(class merge_lists_aggregation const& agg);
virtual void visit(class merge_sets_aggregation const& agg);
virtual void visit(class merge_m2_aggregation const& agg);
virtual void visit(class covariance_aggregation const& agg);
virtual void visit(class correlation_aggregation const& agg);
virtual void visit(class tdigest_aggregation const& agg);
virtual void visit(class merge_tdigest_aggregation const& agg);
};
Expand Down Expand Up @@ -890,6 +896,57 @@ class merge_m2_aggregation final : public groupby_aggregation {
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying COVARIANCE aggregation
*/
class covariance_aggregation final : public groupby_aggregation {
public:
explicit covariance_aggregation() : aggregation{COVARIANCE} {}

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<covariance_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }
};

/**
* @brief Derived aggregation class for specifying CORRELATION aggregation
*/
class correlation_aggregation final : public groupby_aggregation {
public:
explicit correlation_aggregation(correlation_type type) : aggregation{CORRELATION}, _type{type} {}
correlation_type _type;

bool is_equal(aggregation const& _other) const override
{
if (!this->aggregation::is_equal(_other)) { return false; }
auto const& other = dynamic_cast<correlation_aggregation const&>(_other);
return (_type == other._type);
}

size_t do_hash() const override { return this->aggregation::do_hash() ^ hash_impl(); }
nvdbaranec marked this conversation as resolved.
Show resolved Hide resolved

std::unique_ptr<aggregation> clone() const override
{
return std::make_unique<correlation_aggregation>(*this);
}
std::vector<std::unique_ptr<aggregation>> get_simple_aggregations(
data_type col_type, simple_aggregations_collector& collector) const override
{
return collector.visit(col_type, *this);
}
void finalize(aggregation_finalizer& finalizer) const override { finalizer.visit(*this); }

protected:
size_t hash_impl() const { return std::hash<int>{}(static_cast<int>(_type)); }
};

/**
* @brief Derived aggregation class for specifying TDIGEST aggregation
*/
Expand Down Expand Up @@ -1174,6 +1231,18 @@ struct target_type_impl<SourceType, aggregation::MERGE_M2> {
using type = struct_view;
};

// Always use double for COVARIANCE
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::COVARIANCE> {
using type = double;
};

// Always use double for CORRELATION
template <typename SourceType>
struct target_type_impl<SourceType, aggregation::CORRELATION> {
using type = double;
};

// Always use numeric types for TDIGEST
template <typename Source>
struct target_type_impl<Source,
Expand Down Expand Up @@ -1296,6 +1365,10 @@ CUDA_HOST_DEVICE_CALLABLE decltype(auto) aggregation_dispatcher(aggregation::Kin
return f.template operator()<aggregation::MERGE_SETS>(std::forward<Ts>(args)...);
case aggregation::MERGE_M2:
return f.template operator()<aggregation::MERGE_M2>(std::forward<Ts>(args)...);
case aggregation::COVARIANCE:
return f.template operator()<aggregation::COVARIANCE>(std::forward<Ts>(args)...);
case aggregation::CORRELATION:
return f.template operator()<aggregation::CORRELATION>(std::forward<Ts>(args)...);
case aggregation::TDIGEST:
return f.template operator()<aggregation::TDIGEST>(std::forward<Ts>(args)...);
case aggregation::MERGE_TDIGEST:
Expand Down
32 changes: 19 additions & 13 deletions cpp/include/cudf/detail/aggregation/result_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,26 @@

#include <cudf/column/column.hpp>
#include <cudf/detail/aggregation/aggregation.hpp>
#include <cudf/detail/hashing.hpp>
#include <cudf/types.hpp>

#include <unordered_map>

namespace cudf {
namespace detail {
struct aggregation_equality {
bool operator()(aggregation const& lhs, aggregation const& rhs) const
struct pair_column_aggregation_equal_to {
bool operator()(std::pair<column_view, aggregation const&> const& lhs,
std::pair<column_view, aggregation const&> const& rhs) const
{
return lhs.is_equal(rhs);
return is_shallow_equivalent(lhs.first, rhs.first) and lhs.second.is_equal(rhs.second);
}
};

struct aggregation_hash {
size_t operator()(aggregation const& key) const noexcept { return key.do_hash(); }
struct pair_column_aggregation_hash {
size_t operator()(std::pair<column_view, aggregation const&> const& key) const noexcept
{
return hash_combine(shallow_hash(key.first), key.second.do_hash());
}
};

class result_cache {
Expand All @@ -43,19 +49,19 @@ class result_cache {

result_cache(size_t num_columns) : _cache(num_columns) {}

bool has_result(size_t col_idx, aggregation const& agg) const;
bool has_result(column_view const& input, aggregation const& agg) const;

void add_result(size_t col_idx, aggregation const& agg, std::unique_ptr<column>&& col);
void add_result(column_view const& input, aggregation const& agg, std::unique_ptr<column>&& col);

column_view get_result(size_t col_idx, aggregation const& agg) const;
column_view get_result(column_view const& input, aggregation const& agg) const;

std::unique_ptr<column> release_result(size_t col_idx, aggregation const& agg);
std::unique_ptr<column> release_result(column_view const& input, aggregation const& agg);

private:
std::vector<std::unordered_map<std::reference_wrapper<aggregation const>,
std::pair<std::unique_ptr<aggregation>, std::unique_ptr<column>>,
aggregation_hash,
aggregation_equality>>
std::unordered_map<std::pair<column_view, std::reference_wrapper<aggregation const>>,
std::pair<std::unique_ptr<aggregation>, std::unique_ptr<column>>,
pair_column_aggregation_hash,
pair_column_aggregation_equal_to>
_cache;
};

Expand Down
40 changes: 40 additions & 0 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
return visit(col_type, static_cast<aggregation const&>(agg));
}

std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, covariance_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}
std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, correlation_aggregation const& agg)
{
return visit(col_type, static_cast<aggregation const&>(agg));
}
std::vector<std::unique_ptr<aggregation>> simple_aggregations_collector::visit(
data_type col_type, tdigest_aggregation const& agg)
{
Expand Down Expand Up @@ -358,6 +368,16 @@ void aggregation_finalizer::visit(merge_m2_aggregation const& agg)
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(covariance_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(correlation_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
}

void aggregation_finalizer::visit(tdigest_aggregation const& agg)
{
visit(static_cast<aggregation const&>(agg));
Expand Down Expand Up @@ -690,6 +710,26 @@ std::unique_ptr<Base> make_merge_m2_aggregation()
template std::unique_ptr<aggregation> make_merge_m2_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_merge_m2_aggregation<groupby_aggregation>();

/// Factory to create a COVARIANCE aggregation
template <typename Base>
std::unique_ptr<Base> make_covariance_aggregation()
{
return std::make_unique<detail::covariance_aggregation>();
}
template std::unique_ptr<aggregation> make_covariance_aggregation<aggregation>();
template std::unique_ptr<groupby_aggregation> make_covariance_aggregation<groupby_aggregation>();

/// Factory to create a CORRELATION aggregation
template <typename Base>
std::unique_ptr<Base> make_correlation_aggregation(correlation_type type)
{
return std::make_unique<detail::correlation_aggregation>(type);
}
template std::unique_ptr<aggregation> make_correlation_aggregation<aggregation>(
correlation_type type);
template std::unique_ptr<groupby_aggregation> make_correlation_aggregation<groupby_aggregation>(
correlation_type type);

template <typename Base>
std::unique_ptr<Base> make_tdigest_aggregation(int max_centroids)
{
Expand Down
37 changes: 17 additions & 20 deletions cpp/src/aggregation/result_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,36 @@
namespace cudf {
namespace detail {

bool result_cache::has_result(size_t col_idx, aggregation const& agg) const
bool result_cache::has_result(column_view const& input, aggregation const& agg) const
{
if (col_idx > _cache.size()) return false;

auto result_it = _cache[col_idx].find(agg);

return (result_it != _cache[col_idx].end());
return _cache.count({input, agg});
}

void result_cache::add_result(size_t col_idx, aggregation const& agg, std::unique_ptr<column>&& col)
void result_cache::add_result(column_view const& input,
aggregation const& agg,
std::unique_ptr<column>&& col)
{
// We can't guarantee that agg will outlive the cache, so we need to take ownership of a copy.
// To allow lookup by reference, make the key a reference and keep the owner in the value pair.
auto owned_agg = agg.clone();
auto const& key = *owned_agg;
auto value = std::make_pair(std::move(owned_agg), std::move(col));
_cache[col_idx].emplace(key, std::move(value));
auto owned_agg = agg.clone();
auto const& key = *owned_agg;
auto value = std::make_pair(std::move(owned_agg), std::move(col));
_cache[{input, key}] = std::move(value);
}

column_view result_cache::get_result(size_t col_idx, aggregation const& agg) const
column_view result_cache::get_result(column_view const& input, aggregation const& agg) const
{
CUDF_EXPECTS(has_result(col_idx, agg), "Result does not exist in cache");

auto result_it = _cache[col_idx].find(agg);
auto result_it = _cache.find({input, agg});
CUDF_EXPECTS(result_it != _cache.end(), "Result does not exist in cache");
return result_it->second.second->view();
}

std::unique_ptr<column> result_cache::release_result(size_t col_idx, aggregation const& agg)
std::unique_ptr<column> result_cache::release_result(column_view const& input,
aggregation const& agg)
{
CUDF_EXPECTS(has_result(col_idx, agg), "Result does not exist in cache");

auto result_it = _cache[col_idx].extract(agg);
return std::move(result_it.mapped().second);
auto node = _cache.extract({input, agg});
CUDF_EXPECTS(not node.empty(), "Result does not exist in cache");
return std::move(node.mapped().second);
}

} // namespace detail
Expand Down
20 changes: 18 additions & 2 deletions cpp/src/groupby/common/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <cudf/detail/aggregation/result_cache.hpp>
#include <cudf/detail/groupby.hpp>
#include <cudf/utilities/span.hpp>

#include <memory>
#include <vector>

namespace cudf {
Expand All @@ -30,10 +32,24 @@ inline std::vector<aggregation_result> extract_results(host_span<RequestType con
cudf::detail::result_cache& cache)
{
std::vector<aggregation_result> results(requests.size());

std::unordered_map<std::pair<column_view, std::reference_wrapper<aggregation const>>,
column_view,
cudf::detail::pair_column_aggregation_hash,
cudf::detail::pair_column_aggregation_equal_to>
repeated_result;
for (size_t i = 0; i < requests.size(); i++) {
for (auto&& agg : requests[i].aggregations) {
results[i].results.emplace_back(cache.release_result(i, *agg));
if (cache.has_result(requests[i].values, *agg)) {
results[i].results.emplace_back(cache.release_result(requests[i].values, *agg));
repeated_result[{requests[i].values, *agg}] = results[i].results.back()->view();
} else {
auto it = repeated_result.find({requests[i].values, *agg});
if (it != repeated_result.end()) {
results[i].results.emplace_back(std::make_unique<column>(it->second));
} else {
CUDF_FAIL("Cannot extract result from the cache");
}
}
}
}
return results;
Expand Down
Loading