Skip to content

Commit

Permalink
Add enforcement that the output_dtype parameter passed to reduce for …
Browse files Browse the repository at this point in the history
…tdigest aggregations is STRUCT.
  • Loading branch information
nvdbaranec committed Mar 16, 2022
1 parent 98e76ef commit eb6a0c4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
4 changes: 4 additions & 0 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,14 @@ struct reduce_dispatch_functor {
return reduction::merge_sets(col, col_agg->_nulls_equal, col_agg->_nans_equal, stream, mr);
} break;
case aggregation::TDIGEST: {
CUDF_EXPECTS(output_dtype.id() == type_id::STRUCT,
"Tdigest aggregations expect output type to be STRUCT");
auto td_agg = dynamic_cast<tdigest_aggregation const*>(agg.get());
return detail::tdigest::reduce_tdigest(col, td_agg->max_centroids, stream, mr);
} break;
case aggregation::MERGE_TDIGEST: {
CUDF_EXPECTS(output_dtype.id() == type_id::STRUCT,
"Tdigest aggregations expect output type to be STRUCT");
auto td_agg = dynamic_cast<merge_tdigest_aggregation const*>(agg.get());
return detail::tdigest::reduce_merge_tdigest(col, td_agg->max_centroids, stream, mr);
} break;
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/quantiles/percentile_approx_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ void percentile_approx_test(column_view const& _keys,
auto scalar_result =
cudf::reduce(values,
cudf::make_tdigest_aggregation<cudf::reduce_aggregation>(delta),
data_type{type_id::FLOAT64});
data_type{type_id::STRUCT});
auto tbl = static_cast<cudf::struct_scalar const*>(scalar_result.get())->view();
std::vector<std::unique_ptr<cudf::column>> cols;
std::transform(
Expand Down
4 changes: 2 additions & 2 deletions cpp/tests/reductions/tdigest_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct reduce_op {
auto scalar_result =
cudf::reduce(values,
cudf::make_tdigest_aggregation<cudf::reduce_aggregation>(delta),
cudf::data_type{cudf::type_id::FLOAT64});
cudf::data_type{cudf::type_id::STRUCT});
auto tbl = static_cast<cudf::struct_scalar const*>(scalar_result.get())->view();
std::vector<std::unique_ptr<cudf::column>> cols;
std::transform(
Expand All @@ -54,7 +54,7 @@ struct reduce_merge_op {
auto scalar_result =
cudf::reduce(values,
cudf::make_merge_tdigest_aggregation<cudf::reduce_aggregation>(delta),
cudf::data_type{cudf::type_id::FLOAT64});
cudf::data_type{cudf::type_id::STRUCT});
auto tbl = static_cast<cudf::struct_scalar const*>(scalar_result.get())->view();
std::vector<std::unique_ptr<cudf::column>> cols;
std::transform(
Expand Down

0 comments on commit eb6a0c4

Please sign in to comment.