From fb22fc7ef4261bd803ccd9cba5d5e0335e8192e6 Mon Sep 17 00:00:00 2001 From: sperlingxx Date: Fri, 21 May 2021 18:21:49 +0800 Subject: [PATCH] fix list_concat_rows Signed-off-by: sperlingxx --- .../combine/concatenate_list_elements.cu | 40 ++++++++++++++++--- .../concatenate_list_elements_tests.cpp | 25 +++++++----- .../lists/combine/concatenate_rows_tests.cpp | 36 ++++++++++------- 3 files changed, 70 insertions(+), 31 deletions(-) diff --git a/cpp/src/lists/combine/concatenate_list_elements.cu b/cpp/src/lists/combine/concatenate_list_elements.cu index b76cd19d94b..c5a28a8ec5f 100644 --- a/cpp/src/lists/combine/concatenate_list_elements.cu +++ b/cpp/src/lists/combine/concatenate_list_elements.cu @@ -41,6 +41,7 @@ namespace { * concatenation. */ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, + bool build_null_mask, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { @@ -50,9 +51,13 @@ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, auto out_offsets = make_numeric_column( data_type{type_id::INT32}, num_rows + 1, mask_state::UNALLOCATED, stream, mr); + // The array of int8_t stores validities for the output list elements. + auto validities = rmm::device_uvector(build_null_mask ? num_rows : 0, stream); + auto const d_out_offsets = out_offsets->mutable_view().template begin(); auto const d_row_offsets = lists_column_view(input).offsets_begin(); auto const d_list_offsets = lists_column_view(lists_column_view(input).child()).offsets_begin(); + auto const lists_dv_ptr = column_device_view::create(lists_column_view(input).child()); // Concatenating the lists at the same row by converting the entry offsets from the child column // into row offsets of the root column. Those entry offsets are subtracted by the first entry @@ -62,7 +67,22 @@ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, iter, iter + num_rows + 1, d_out_offsets, - [d_row_offsets, d_list_offsets] __device__(auto const idx) { + [d_row_offsets, + d_list_offsets, + lists_dv = *lists_dv_ptr, + d_validities = validities.begin(), + build_null_mask, + iter] __device__(auto const idx) { + if (build_null_mask) { + // The output row will be null only if all lists on the input row are null. + auto const is_valid = thrust::any_of(thrust::seq, + iter + d_row_offsets[idx], + iter + d_row_offsets[idx + 1], + [&] __device__(auto const list_idx) { + return lists_dv.is_valid(list_idx); + }); + d_validities[idx] = static_cast(is_valid); + } auto const start_offset = d_list_offsets[d_row_offsets[0]]; return d_list_offsets[d_row_offsets[idx]] - start_offset; }); @@ -71,11 +91,18 @@ std::unique_ptr concatenate_lists_ignore_null(column_view const& input, auto out_entries = std::make_unique( lists_column_view(lists_column_view(input).get_sliced_child(stream)).get_sliced_child(stream)); + auto [null_mask, null_count] = [&] { + return build_null_mask + ? cudf::detail::valid_if( + validities.begin(), validities.end(), thrust::identity{}, stream, mr) + : std::make_pair(cudf::detail::copy_bitmask(input, stream, mr), input.null_count()); + }(); + return make_lists_column(num_rows, std::move(out_offsets), std::move(out_entries), - input.null_count(), - cudf::detail::copy_bitmask(input, stream, mr), + null_count, + null_count > 0 ? std::move(null_mask) : rmm::device_buffer{}, stream, mr); } @@ -241,9 +268,10 @@ std::unique_ptr concatenate_list_elements(column_view const& input, if (input.size() == 0) { return cudf::empty_like(input); } - return (null_policy == concatenate_null_policy::IGNORE || - !lists_column_view(input).child().has_nulls()) - ? concatenate_lists_ignore_null(input, stream, mr) + bool has_null_list = lists_column_view(input).child().has_nulls(); + + return (null_policy == concatenate_null_policy::IGNORE || !has_null_list) + ? concatenate_lists_ignore_null(input, has_null_list, stream, mr) : concatenate_lists_nullifying_rows(input, stream, mr); } diff --git a/cpp/tests/lists/combine/concatenate_list_elements_tests.cpp b/cpp/tests/lists/combine/concatenate_list_elements_tests.cpp index de6307471a9..7d79cf4aebe 100644 --- a/cpp/tests/lists/combine/concatenate_list_elements_tests.cpp +++ b/cpp/tests/lists/combine/concatenate_list_elements_tests.cpp @@ -147,19 +147,23 @@ TYPED_TEST(ConcatenateListElementsTypedTest, SimpleInputWithNulls) auto row5 = ListsCol{ListsCol{{1, 2, 3, null}, null_at(3)}, ListsCol{{null}, null_at(0)}, ListsCol{{null, null, null, null, null}, all_nulls()}}; - auto const col = build_lists_col(row0, row1, row2, row3, row4, row5); + auto row6 = + ListsCol{{ListsCol{} /*NULL*/, ListsCol{} /*NULL*/, ListsCol{} /*NULL*/}, all_nulls()}; + auto const col = build_lists_col(row0, row1, row2, row3, row4, row5, row6); // Ignore null list elements. { auto const results = cudf::lists::concatenate_list_elements(col); auto const expected = - ListsCol{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})}, - ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})}, - ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})}, - ListsCol{{null, 18}, null_at(0)}, - ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})}, - ListsCol{{1, 2, 3, null, null, null, null, null, null, null}, - null_at({3, 4, 5, 6, 7, 8, 9})}}; + ListsCol{{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})}, + ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})}, + ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})}, + ListsCol{{null, 18}, null_at(0)}, + ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})}, + ListsCol{{1, 2, 3, null, null, null, null, null, null, null}, + null_at({3, 4, 5, 6, 7, 8, 9})}, + ListsCol{} /*NULL*/}, + null_at(6)}; CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *results, print_all); } @@ -174,8 +178,9 @@ TYPED_TEST(ConcatenateListElementsTypedTest, SimpleInputWithNulls) ListsCol{} /*NULL*/, ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})}, ListsCol{{1, 2, 3, null, null, null, null, null, null, null}, - null_at({3, 4, 5, 6, 7, 8, 9})}}, - null_at({0, 2, 3})}; + null_at({3, 4, 5, 6, 7, 8, 9})}, + ListsCol{} /*NULL*/}, + null_at({0, 2, 3, 6})}; CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *results, print_all); } } diff --git a/cpp/tests/lists/combine/concatenate_rows_tests.cpp b/cpp/tests/lists/combine/concatenate_rows_tests.cpp index 3e085af7740..af22f329634 100644 --- a/cpp/tests/lists/combine/concatenate_rows_tests.cpp +++ b/cpp/tests/lists/combine/concatenate_rows_tests.cpp @@ -184,24 +184,27 @@ TYPED_TEST(ListConcatenateRowsTypedTest, SimpleInputWithNulls) ListsCol{{null, 2, 3, 4}, null_at(0)}, ListsCol{} /*NULL*/, ListsCol{{1, 2, null, 4}, null_at(2)}, - ListsCol{{1, 2, 3, null}, null_at(3)}}, - null_at(3)} + ListsCol{{1, 2, 3, null}, null_at(3)}, + ListsCol{} /*NULL*/}, + null_at({3, 6})} .release(); auto const col2 = ListsCol{{ListsCol{{10, 11, 12, null}, null_at(3)}, ListsCol{{13, 14, 15, 16, 17, null}, null_at(5)}, ListsCol{} /*NULL*/, ListsCol{{null, 18}, null_at(0)}, ListsCol{{19, 20, null}, null_at(2)}, - ListsCol{{null}, null_at(0)}}, - null_at(2)} + ListsCol{{null}, null_at(0)}, + ListsCol{} /*NULL*/}, + null_at({2, 6})} .release(); auto const col3 = ListsCol{{ListsCol{} /*NULL*/, ListsCol{{20, null}, null_at(1)}, ListsCol{{null, 21, null, null}, null_at({0, 2, 3})}, ListsCol{}, ListsCol{22, 23, 24, 25}, - ListsCol{{null, null, null, null, null}, all_nulls()}}, - null_at(0)} + ListsCol{{null, null, null, null, null}, all_nulls()}, + ListsCol{} /*NULL*/}, + null_at({0, 6})} .release(); // Ignore null list elements @@ -209,13 +212,15 @@ TYPED_TEST(ListConcatenateRowsTypedTest, SimpleInputWithNulls) auto const results = cudf::lists::concatenate_rows(TView{{col1->view(), col2->view(), col3->view()}}); auto const expected = - ListsCol{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})}, - ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})}, - ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})}, - ListsCol{{null, 18}, null_at(0)}, - ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})}, - ListsCol{{1, 2, 3, null, null, null, null, null, null, null}, - null_at({3, 4, 5, 6, 7, 8, 9})}} + ListsCol{{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})}, + ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})}, + ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})}, + ListsCol{{null, 18}, null_at(0)}, + ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})}, + ListsCol{{1, 2, 3, null, null, null, null, null, null, null}, + null_at({3, 4, 5, 6, 7, 8, 9})}, + ListsCol{} /*NULL*/}, + null_at(6)} .release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*expected, *results, print_all); } @@ -232,8 +237,9 @@ TYPED_TEST(ListConcatenateRowsTypedTest, SimpleInputWithNulls) ListsCol{} /*NULL*/, ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})}, ListsCol{{1, 2, 3, null, null, null, null, null, null, null}, - null_at({3, 4, 5, 6, 7, 8, 9})}}, - null_at({0, 2, 3})} + null_at({3, 4, 5, 6, 7, 8, 9})}, + ListsCol{} /*NULL*/}, + null_at({0, 2, 3, 6})} .release(); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*expected, *results, print_all); }