// Patch summary for cpp/src/join/hash_join.cu.
//
// The three size queries of hash_join::hash_join_impl -- inner_join_size,
// left_join_size and full_join_size -- stop creating a device view directly
// from `probe`. Each of them now:
//   1. calls structs::detail::flatten_nested_columns(probe, {}, {},
//      structs::detail::column_nullability::FORCE);
//   2. takes element 0 of the returned value (std::get<0>) as the flattened
//      probe table;
//   3. creates table_device_views for `_build` and for the flattened probe;
//   4. forwards both views, `*_hash_table`, `compare_nulls` and `stream` to
//      cudf::detail::compute_join_output_size (inner/left) or to
//      get_full_join_size (full, which additionally receives `mr`).
//
// inner_join_size checks `_hash_table` with CUDF_EXPECTS; left_join_size and
// full_join_size instead return probe.num_rows() early when it is null.
//
// NOTE(review): `flattened_probe` is kept in a named local for the whole
//   function, presumably so whatever the flatten result owns outlives the
//   device view -- confirm.
// NOTE(review): `_build` is passed as-is; presumably it is already stored in
//   flattened form -- confirm against the hash_join_impl constructor.
// NOTE(review): the early-exit comment in the full_join_size hunk still reads
//   "Trivial left join case"; looks like a copy-paste leftover -- confirm.
// NOTE(review): this copy of the patch has its line breaks collapsed and
//   appears to have lost some template argument lists (a bare `template`,
//   `std::numeric_limits::min()`); recover them from the upstream change
//   before trying to apply it.
//
// The end of this span opens the diff for cpp/tests/join/join_tests.cpp (file
// header, first hunk header and the NoneValue context line).
diff --git a/cpp/src/join/hash_join.cu b/cpp/src/join/hash_join.cu index 50cc479fcf4..ee1eaeaed47 100644 --- a/cpp/src/join/hash_join.cu +++ b/cpp/src/join/hash_join.cu @@ -349,11 +349,15 @@ std::size_t hash_join::hash_join_impl::inner_join_size(cudf::table_view const& p CUDF_FUNC_RANGE(); CUDF_EXPECTS(_hash_table, "Hash table of hash join is null."); - auto build_table = cudf::table_device_view::create(_build, stream); - auto probe_table = cudf::table_device_view::create(probe, stream); + auto flattened_probe = structs::detail::flatten_nested_columns( + probe, {}, {}, structs::detail::column_nullability::FORCE); + auto const flattened_probe_table = std::get<0>(flattened_probe); + + auto build_table_ptr = cudf::table_device_view::create(_build, stream); + auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); return cudf::detail::compute_join_output_size( - *build_table, *probe_table, *_hash_table, compare_nulls, stream); + *build_table_ptr, *flattened_probe_table_ptr, *_hash_table, compare_nulls, stream); } std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& probe, @@ -365,11 +369,15 @@ std::size_t hash_join::hash_join_impl::left_join_size(cudf::table_view const& pr // Trivial left join case - exit early if (!_hash_table) { return probe.num_rows(); } - auto build_table = cudf::table_device_view::create(_build, stream); - auto probe_table = cudf::table_device_view::create(probe, stream); + auto flattened_probe = structs::detail::flatten_nested_columns( + probe, {}, {}, structs::detail::column_nullability::FORCE); + auto const flattened_probe_table = std::get<0>(flattened_probe); + + auto build_table_ptr = cudf::table_device_view::create(_build, stream); + auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); return cudf::detail::compute_join_output_size( - *build_table, *probe_table, *_hash_table, compare_nulls, stream); + *build_table_ptr, 
*flattened_probe_table_ptr, *_hash_table, compare_nulls, stream); } std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& probe, @@ -382,10 +390,15 @@ std::size_t hash_join::hash_join_impl::full_join_size(cudf::table_view const& pr // Trivial left join case - exit early if (!_hash_table) { return probe.num_rows(); } - auto build_table = cudf::table_device_view::create(_build, stream); - auto probe_table = cudf::table_device_view::create(probe, stream); + auto flattened_probe = structs::detail::flatten_nested_columns( + probe, {}, {}, structs::detail::column_nullability::FORCE); + auto const flattened_probe_table = std::get<0>(flattened_probe); + + auto build_table_ptr = cudf::table_device_view::create(_build, stream); + auto flattened_probe_table_ptr = cudf::table_device_view::create(flattened_probe_table, stream); - return get_full_join_size(*build_table, *probe_table, *_hash_table, compare_nulls, stream, mr); + return get_full_join_size( + *build_table_ptr, *flattened_probe_table_ptr, *_hash_table, compare_nulls, stream, mr); } template diff --git a/cpp/tests/join/join_tests.cpp b/cpp/tests/join/join_tests.cpp index e468368842a..af998e366e9 100644 --- a/cpp/tests/join/join_tests.cpp +++ b/cpp/tests/join/join_tests.cpp @@ -44,6 +44,28 @@ constexpr cudf::size_type NoneValue = std::numeric_limits::min(); // TODO: how to test if this isn't public? 
// JoinTest fixture and join tests (cpp/tests/join/join_tests.cpp hunks).
//
// gather_maps_as_tables(expected_left_map, expected_right_map, result):
//   Wraps result.first and result.second as two INT32 column_views (built from
//   their size() and data()), places them in a table_view, and sorts that
//   table with cudf::sorted_order followed by cudf::gather. The two expected
//   map columns are placed in a table_view and sorted the same way.
//   Returns std::make_pair(sorted expected table, sorted result table) --
//   expected first, result second. Both sides are sorted before comparison,
//   so the comparison does not depend on the row order of the join output.
//
// HashJoinSequentialProbes: the full_join / left_join / inner_join sections
//   drop their inline copies of that sort-and-gather code (and the
//   CVector/Table construction of the expected columns) and call the helper,
//   then compare with CUDF_TEST_EXPECT_TABLES_EQUIVALENT.
//
// HashJoinWithStructsAndNulls (added): builds two single-column tables, t0 and
//   t1, whose column is a structs_column_wrapper over (names, ages, is_human).
//   A cudf::hash_join is constructed on t1 with null_equality::EQUAL and
//   probed with t0. The test expects left_join_size == 5, inner_join_size == 2
//   and full_join_size == 8, and checks the maps returned by left_join,
//   inner_join and full_join against hand-written expected columns.
//
// NOTE(review): the second initializer list of the is_human columns is
//   presumably the validity mask (test name says "AndNulls") -- confirm.
// NOTE(review): NoneValue appears in the expected maps where a row has no
//   partner on the other side; presumably it mirrors the join's no-match
//   sentinel -- confirm.
// NOTE(review): this copy has collapsed line breaks and appears to have lost
//   template argument lists (`std::pair, std::unique_ptr>`, `static_cast(...)`,
//   `column_wrapper{{...}}`); recover them from the upstream change.
struct JoinTest : public cudf::test::BaseFixture { + std::pair, std::unique_ptr> gather_maps_as_tables( + cudf::column_view const& expected_left_map, + cudf::column_view const& expected_right_map, + std::pair>, + std::unique_ptr>> const& result) + { + auto result_table = + cudf::table_view({cudf::column_view{cudf::data_type{cudf::type_id::INT32}, + static_cast(result.first->size()), + result.first->data()}, + cudf::column_view{cudf::data_type{cudf::type_id::INT32}, + static_cast(result.second->size()), + result.second->data()}}); + auto result_sort_order = cudf::sorted_order(result_table); + auto sorted_result = cudf::gather(result_table, *result_sort_order); + + cudf::table_view gold({expected_left_map, expected_right_map}); + auto gold_sort_order = cudf::sorted_order(gold); + auto sorted_gold = cudf::gather(gold, *gold_sort_order); + + return std::make_pair(std::move(sorted_gold), std::move(sorted_result)); + } }; TEST_F(JoinTest, EmptySentinelRepro) @@ -1232,27 +1254,9 @@ TEST_F(JoinTest, HashJoinSequentialProbes) EXPECT_EQ(output_size, size_gold); auto result = hash_join.full_join(t0, cudf::null_equality::EQUAL, optional_size); - auto result_table = - cudf::table_view({cudf::column_view{cudf::data_type{cudf::type_id::INT32}, - static_cast(result.first->size()), - result.first->data()}, - cudf::column_view{cudf::data_type{cudf::type_id::INT32}, - static_cast(result.second->size()), - result.second->data()}}); - auto result_sort_order = cudf::sorted_order(result_table); - auto sorted_result = cudf::gather(result_table, *result_sort_order); - column_wrapper col_gold_0{{NoneValue, NoneValue, NoneValue, NoneValue, 4, 0, 1, 2, 3}}; column_wrapper col_gold_1{{0, 1, 2, 3, 4, NoneValue, NoneValue, NoneValue, NoneValue}}; - - CVector cols_gold; - cols_gold.push_back(col_gold_0.release()); - cols_gold.push_back(col_gold_1.release()); - - Table gold(std::move(cols_gold)); - auto gold_sort_order = cudf::sorted_order(gold.view()); - auto sorted_gold = 
cudf::gather(gold.view(), *gold_sort_order); - + auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*sorted_gold, *sorted_result); } @@ -1270,27 +1274,9 @@ TEST_F(JoinTest, HashJoinSequentialProbes) EXPECT_EQ(output_size, size_gold); auto result = hash_join.left_join(t0, cudf::null_equality::EQUAL, optional_size); - auto result_table = - cudf::table_view({cudf::column_view{cudf::data_type{cudf::type_id::INT32}, - static_cast(result.first->size()), - result.first->data()}, - cudf::column_view{cudf::data_type{cudf::type_id::INT32}, - static_cast(result.second->size()), - result.second->data()}}); - auto result_sort_order = cudf::sorted_order(result_table); - auto sorted_result = cudf::gather(result_table, *result_sort_order); - column_wrapper col_gold_0{{0, 1, 2, 3, 4}}; column_wrapper col_gold_1{{NoneValue, NoneValue, NoneValue, NoneValue, 4}}; - - CVector cols_gold; - cols_gold.push_back(col_gold_0.release()); - cols_gold.push_back(col_gold_1.release()); - - Table gold(std::move(cols_gold)); - auto gold_sort_order = cudf::sorted_order(gold.view()); - auto sorted_gold = cudf::gather(gold.view(), *gold_sort_order); - + auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*sorted_gold, *sorted_result); } @@ -1308,27 +1294,69 @@ TEST_F(JoinTest, HashJoinSequentialProbes) EXPECT_EQ(output_size, size_gold); auto result = hash_join.inner_join(t0, cudf::null_equality::EQUAL, optional_size); - auto result_table = - cudf::table_view({cudf::column_view{cudf::data_type{cudf::type_id::INT32}, - static_cast(result.first->size()), - result.first->data()}, - cudf::column_view{cudf::data_type{cudf::type_id::INT32}, - static_cast(result.second->size()), - result.second->data()}}); - auto result_sort_order = cudf::sorted_order(result_table); - auto sorted_result = cudf::gather(result_table, *result_sort_order); - 
column_wrapper col_gold_0{{2, 4, 0}}; column_wrapper col_gold_1{{1, 1, 4}}; + auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*sorted_gold, *sorted_result); + } +} - CVector cols_gold; - cols_gold.push_back(col_gold_0.release()); - cols_gold.push_back(col_gold_1.release()); +TEST_F(JoinTest, HashJoinWithStructsAndNulls) +{ + auto col0_names_col = strcol_wrapper{ + "Samuel Vimes", "Carrot Ironfoundersson", "Detritus", "Samuel Vimes", "Angua von Überwald"}; + auto col0_ages_col = column_wrapper{{48, 27, 351, 31, 25}}; + + auto col0_is_human_col = column_wrapper{{true, true, false, false, false}, {1, 1, 0, 1, 0}}; - Table gold(std::move(cols_gold)); - auto gold_sort_order = cudf::sorted_order(gold.view()); - auto sorted_gold = cudf::gather(gold.view(), *gold_sort_order); + auto col0 = + cudf::test::structs_column_wrapper{{col0_names_col, col0_ages_col, col0_is_human_col}}; + + auto col1_names_col = strcol_wrapper{ + "Samuel Vimes", "Detritus", "Detritus", "Carrot Ironfoundersson", "Angua von Überwald"}; + auto col1_ages_col = column_wrapper{{48, 35, 351, 22, 25}}; + auto col1_is_human_col = column_wrapper{{true, true, false, false, true}, {1, 1, 0, 1, 1}}; + + auto col1 = + cudf::test::structs_column_wrapper{{col1_names_col, col1_ages_col, col1_is_human_col}}; + + CVector cols0, cols1; + cols0.push_back(col0.release()); + cols1.push_back(col1.release()); + + Table t0(std::move(cols0)); + Table t1(std::move(cols1)); + + auto hash_join = cudf::hash_join(t1, cudf::null_equality::EQUAL); + + { + auto output_size = hash_join.left_join_size(t0); + EXPECT_EQ(5, output_size); + auto result = hash_join.left_join(t0, cudf::null_equality::EQUAL, output_size); + column_wrapper col_gold_0{{0, 1, 2, 3, 4}}; + column_wrapper col_gold_1{{0, NoneValue, 2, NoneValue, NoneValue}}; + auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); + 
CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*sorted_gold, *sorted_result); + } + + { + auto output_size = hash_join.inner_join_size(t0); + EXPECT_EQ(2, output_size); + auto result = hash_join.inner_join(t0, cudf::null_equality::EQUAL, output_size); + column_wrapper col_gold_0{{0, 2}}; + column_wrapper col_gold_1{{0, 2}}; + auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); + CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*sorted_gold, *sorted_result); + } + + { + auto output_size = hash_join.full_join_size(t0); + EXPECT_EQ(8, output_size); + auto result = hash_join.full_join(t0, cudf::null_equality::EQUAL, output_size); + column_wrapper col_gold_0{{NoneValue, NoneValue, NoneValue, 0, 1, 2, 3, 4}}; + column_wrapper col_gold_1{{1, 3, 4, 0, NoneValue, 2, NoneValue, NoneValue}}; + auto const [sorted_gold, sorted_result] = gather_maps_as_tables(col_gold_0, col_gold_1, result); CUDF_TEST_EXPECT_TABLES_EQUIVALENT(*sorted_gold, *sorted_result); } }