Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrong output for collect_list/collect_set of lists column #15243

Merged
merged 10 commits into from
Mar 13, 2024
13 changes: 7 additions & 6 deletions cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -177,15 +177,16 @@ std::unique_ptr<scalar> reduce(column_view const& col,
std::move(*reduction::detail::make_empty_histogram_like(col.child(0))), true, stream, mr);
}

if (output_dtype.id() == type_id::LIST) {
if (col.type() == output_dtype) { return make_empty_scalar_like(col, stream, mr); }
// Under some circumstance, the output type will become the List of input type,
// such as: collect_list or collect_set. So, we have to handcraft the default scalar.
if (agg.kind == aggregation::COLLECT_LIST || agg.kind == aggregation::COLLECT_SET) {
auto scalar = make_list_scalar(empty_like(col)->view(), stream, mr);
scalar->set_valid_async(false, stream);
return scalar;
}
if (output_dtype.id() == type_id::STRUCT) { return make_empty_scalar_like(col, stream, mr); }

// `make_default_constructed_scalar` does not support nested type.
if (output_dtype.id() == type_id::STRUCT || output_dtype.id() == type_id::LIST) {
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
return make_empty_scalar_like(col, stream, mr);
}

auto result = make_default_constructed_scalar(output_dtype, stream, mr);
if (agg.kind == aggregation::ANY || agg.kind == aggregation::ALL) {
Expand Down
51 changes: 50 additions & 1 deletion cpp/tests/reductions/collect_ops_tests.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -367,3 +367,52 @@ TEST_F(CollectTest, CollectEmptys)
ret = collect_set(all_nulls, cudf::make_collect_set_aggregation<cudf::reduce_aggregation>());
CUDF_TEST_EXPECT_COLUMNS_EQUAL(int_col{}, dynamic_cast<cudf::list_scalar*>(ret.get())->view());
}

TEST_F(CollectTest, CollectAllNulls)
{
using int_col = cudf::test::fixed_width_column_wrapper<int32_t>;
using namespace cudf::test::iterators;

auto const input = int_col{{0, 0, 0, 0, 0, 0}, all_nulls()};
auto const expected = int_col{};

{
auto const agg =
cudf::make_collect_list_aggregation<cudf::reduce_aggregation>(cudf::null_policy::EXCLUDE);
auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected,
dynamic_cast<cudf::list_scalar*>(result.get())->view());
}
{
auto const agg = cudf::make_collect_set_aggregation<cudf::reduce_aggregation>(
cudf::null_policy::EXCLUDE, cudf::null_equality::UNEQUAL, cudf::nan_equality::ALL_EQUAL);
auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected,
dynamic_cast<cudf::list_scalar*>(result.get())->view());
}
}

TEST_F(CollectTest, CollectAllNullsWithLists)
{
using LCW = cudf::test::lists_column_wrapper<int32_t>;
using namespace cudf::test::iterators;

// list<list<int>>
auto const input = LCW{{LCW{LCW{1, 2, 3}, LCW{4, 5, 6}}, LCW{{1, 2, 3}}}, all_nulls()};
auto const expected = cudf::empty_like(input);

{
auto const agg =
cudf::make_collect_list_aggregation<cudf::reduce_aggregation>(cudf::null_policy::EXCLUDE);
auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected->view(),
dynamic_cast<cudf::list_scalar*>(result.get())->view());
}
{
auto const agg = cudf::make_collect_set_aggregation<cudf::reduce_aggregation>(
cudf::null_policy::EXCLUDE, cudf::null_equality::UNEQUAL, cudf::nan_equality::ALL_EQUAL);
auto const result = cudf::reduce(input, *agg, cudf::data_type{cudf::type_id::LIST});
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected->view(),
dynamic_cast<cudf::list_scalar*>(result.get())->view());
}
}
Loading