Skip to content

Commit

Permalink
Improve hash join detail functions (#10273)
Browse files Browse the repository at this point in the history
This PR includes several changes in the hash join `detail` functions:

- Fixes a bug where public join APIs were invoked in `detail` join functions. External invocations are replaced with the corresponding `detail` ones.
- Uses structured bindings to improve code readability
- Adds an early exit before the actual join operation to improve performance

Authors:
  - Yunsong Wang (https://github.com/PointKernel)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #10273
  • Loading branch information
PointKernel authored Feb 12, 2022
1 parent 317553f commit 7f2a16f
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions cpp/src/join/join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <join/hash_join.cuh>
#include <join/join_common_utils.hpp>
#include "join/hash_join.cuh"
#include "join/join_common_utils.hpp"

#include <cudf/detail/gather.cuh>
#include <cudf/dictionary/detail/update_keys.hpp>
Expand Down Expand Up @@ -51,8 +51,8 @@ inner_join(table_view const& left_input,
// build the hash map from the smaller table.
if (right.num_rows() > left.num_rows()) {
cudf::hash_join hj_obj(left, compare_nulls, stream);
auto result = hj_obj.inner_join(right, std::nullopt, stream, mr);
return std::make_pair(std::move(result.second), std::move(result.first));
auto [right_result, left_result] = hj_obj.inner_join(right, std::nullopt, stream, mr);
return std::make_pair(std::move(left_result), std::move(right_result));
} else {
cudf::hash_join hj_obj(right, compare_nulls, stream);
return hj_obj.inner_join(left, std::nullopt, stream, mr);
Expand All @@ -78,16 +78,17 @@ std::unique_ptr<table> inner_join(table_view const& left_input,
auto const left = scatter_columns(matched.second.front(), left_on, left_input);
auto const right = scatter_columns(matched.second.back(), right_on, right_input);

auto join_indices = inner_join(left.select(left_on), right.select(right_on), compare_nulls, mr);
auto const [left_join_indices, right_join_indices] = cudf::detail::inner_join(
left.select(left_on), right.select(right_on), compare_nulls, stream, mr);
std::unique_ptr<table> left_result = detail::gather(left,
join_indices.first->begin(),
join_indices.first->end(),
left_join_indices->begin(),
left_join_indices->end(),
out_of_bounds_policy::DONT_CHECK,
stream,
mr);
std::unique_ptr<table> right_result = detail::gather(right,
join_indices.second->begin(),
join_indices.second->end(),
right_join_indices->begin(),
right_join_indices->end(),
out_of_bounds_policy::DONT_CHECK,
stream,
mr);
Expand Down Expand Up @@ -134,23 +135,24 @@ std::unique_ptr<table> left_join(table_view const& left_input,
table_view const left = scatter_columns(matched.second.front(), left_on, left_input);
table_view const right = scatter_columns(matched.second.back(), right_on, right_input);

auto join_indices = left_join(left.select(left_on), right.select(right_on), compare_nulls);

if ((left_on.empty() || right_on.empty()) ||
is_trivial_join(left, right, cudf::detail::join_kind::LEFT_JOIN)) {
auto probe_build_pair = get_empty_joined_table(left, right);
return cudf::detail::combine_table_pair(std::move(probe_build_pair.first),
std::move(probe_build_pair.second));
if ((left_on.empty() or right_on.empty()) or
cudf::detail::is_trivial_join(left, right, cudf::detail::join_kind::LEFT_JOIN)) {
auto [left_empty_table, right_empty_table] = get_empty_joined_table(left, right);
return cudf::detail::combine_table_pair(std::move(left_empty_table),
std::move(right_empty_table));
}

auto const [left_join_indices, right_join_indices] = cudf::detail::left_join(
left.select(left_on), right.select(right_on), compare_nulls, stream, mr);
std::unique_ptr<table> left_result = detail::gather(left,
join_indices.first->begin(),
join_indices.first->end(),
left_join_indices->begin(),
left_join_indices->end(),
out_of_bounds_policy::NULLIFY,
stream,
mr);
std::unique_ptr<table> right_result = detail::gather(right,
join_indices.second->begin(),
join_indices.second->end(),
right_join_indices->begin(),
right_join_indices->end(),
out_of_bounds_policy::NULLIFY,
stream,
mr);
Expand Down Expand Up @@ -197,23 +199,24 @@ std::unique_ptr<table> full_join(table_view const& left_input,
table_view const left = scatter_columns(matched.second.front(), left_on, left_input);
table_view const right = scatter_columns(matched.second.back(), right_on, right_input);

auto join_indices = full_join(left.select(left_on), right.select(right_on), compare_nulls);

if ((left_on.empty() || right_on.empty()) ||
is_trivial_join(left, right, cudf::detail::join_kind::FULL_JOIN)) {
auto probe_build_pair = get_empty_joined_table(left, right);
return cudf::detail::combine_table_pair(std::move(probe_build_pair.first),
std::move(probe_build_pair.second));
if ((left_on.empty() or right_on.empty()) or
cudf::detail::is_trivial_join(left, right, cudf::detail::join_kind::FULL_JOIN)) {
auto [left_empty_table, right_empty_table] = get_empty_joined_table(left, right);
return cudf::detail::combine_table_pair(std::move(left_empty_table),
std::move(right_empty_table));
}

auto const [left_join_indices, right_join_indices] = cudf::detail::full_join(
left.select(left_on), right.select(right_on), compare_nulls, stream, mr);
std::unique_ptr<table> left_result = detail::gather(left,
join_indices.first->begin(),
join_indices.first->end(),
left_join_indices->begin(),
left_join_indices->end(),
out_of_bounds_policy::NULLIFY,
stream,
mr);
std::unique_ptr<table> right_result = detail::gather(right,
join_indices.second->begin(),
join_indices.second->end(),
right_join_indices->begin(),
right_join_indices->end(),
out_of_bounds_policy::NULLIFY,
stream,
mr);
Expand Down

0 comments on commit 7f2a16f

Please sign in to comment.