diff --git a/cpp/src/join/join.cu b/cpp/src/join/join.cu index ef9e7867a2d..7a478ca2eb3 100644 --- a/cpp/src/join/join.cu +++ b/cpp/src/join/join.cu @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include +#include "join/hash_join.cuh" +#include "join/join_common_utils.hpp" #include #include @@ -51,8 +51,8 @@ inner_join(table_view const& left_input, // build the hash map from the smaller table. if (right.num_rows() > left.num_rows()) { cudf::hash_join hj_obj(left, compare_nulls, stream); - auto result = hj_obj.inner_join(right, std::nullopt, stream, mr); - return std::make_pair(std::move(result.second), std::move(result.first)); + auto [right_result, left_result] = hj_obj.inner_join(right, std::nullopt, stream, mr); + return std::make_pair(std::move(left_result), std::move(right_result)); } else { cudf::hash_join hj_obj(right, compare_nulls, stream); return hj_obj.inner_join(left, std::nullopt, stream, mr); @@ -78,16 +78,17 @@ std::unique_ptr inner_join(table_view const& left_input, auto const left = scatter_columns(matched.second.front(), left_on, left_input); auto const right = scatter_columns(matched.second.back(), right_on, right_input); - auto join_indices = inner_join(left.select(left_on), right.select(right_on), compare_nulls, mr); + auto const [left_join_indices, right_join_indices] = cudf::detail::inner_join( + left.select(left_on), right.select(right_on), compare_nulls, stream, mr); std::unique_ptr
left_result = detail::gather(left, - join_indices.first->begin(), - join_indices.first->end(), + left_join_indices->begin(), + left_join_indices->end(), out_of_bounds_policy::DONT_CHECK, stream, mr); std::unique_ptr
right_result = detail::gather(right, - join_indices.second->begin(), - join_indices.second->end(), + right_join_indices->begin(), + right_join_indices->end(), out_of_bounds_policy::DONT_CHECK, stream, mr); @@ -134,23 +135,24 @@ std::unique_ptr
left_join(table_view const& left_input, table_view const left = scatter_columns(matched.second.front(), left_on, left_input); table_view const right = scatter_columns(matched.second.back(), right_on, right_input); - auto join_indices = left_join(left.select(left_on), right.select(right_on), compare_nulls); - - if ((left_on.empty() || right_on.empty()) || - is_trivial_join(left, right, cudf::detail::join_kind::LEFT_JOIN)) { - auto probe_build_pair = get_empty_joined_table(left, right); - return cudf::detail::combine_table_pair(std::move(probe_build_pair.first), - std::move(probe_build_pair.second)); + if ((left_on.empty() or right_on.empty()) or + cudf::detail::is_trivial_join(left, right, cudf::detail::join_kind::LEFT_JOIN)) { + auto [left_empty_table, right_empty_table] = get_empty_joined_table(left, right); + return cudf::detail::combine_table_pair(std::move(left_empty_table), + std::move(right_empty_table)); } + + auto const [left_join_indices, right_join_indices] = cudf::detail::left_join( + left.select(left_on), right.select(right_on), compare_nulls, stream, mr); std::unique_ptr
left_result = detail::gather(left, - join_indices.first->begin(), - join_indices.first->end(), + left_join_indices->begin(), + left_join_indices->end(), out_of_bounds_policy::NULLIFY, stream, mr); std::unique_ptr
right_result = detail::gather(right, - join_indices.second->begin(), - join_indices.second->end(), + right_join_indices->begin(), + right_join_indices->end(), out_of_bounds_policy::NULLIFY, stream, mr); @@ -197,23 +199,24 @@ std::unique_ptr
full_join(table_view const& left_input, table_view const left = scatter_columns(matched.second.front(), left_on, left_input); table_view const right = scatter_columns(matched.second.back(), right_on, right_input); - auto join_indices = full_join(left.select(left_on), right.select(right_on), compare_nulls); - - if ((left_on.empty() || right_on.empty()) || - is_trivial_join(left, right, cudf::detail::join_kind::FULL_JOIN)) { - auto probe_build_pair = get_empty_joined_table(left, right); - return cudf::detail::combine_table_pair(std::move(probe_build_pair.first), - std::move(probe_build_pair.second)); + if ((left_on.empty() or right_on.empty()) or + cudf::detail::is_trivial_join(left, right, cudf::detail::join_kind::FULL_JOIN)) { + auto [left_empty_table, right_empty_table] = get_empty_joined_table(left, right); + return cudf::detail::combine_table_pair(std::move(left_empty_table), + std::move(right_empty_table)); } + + auto const [left_join_indices, right_join_indices] = cudf::detail::full_join( + left.select(left_on), right.select(right_on), compare_nulls, stream, mr); std::unique_ptr
left_result = detail::gather(left, - join_indices.first->begin(), - join_indices.first->end(), + left_join_indices->begin(), + left_join_indices->end(), out_of_bounds_policy::NULLIFY, stream, mr); std::unique_ptr
right_result = detail::gather(right, - join_indices.second->begin(), - join_indices.second->end(), + right_join_indices->begin(), + right_join_indices->end(), out_of_bounds_policy::NULLIFY, stream, mr);