diff --git a/cpp/src/join/conditional_join.cu b/cpp/src/join/conditional_join.cu
index 3aa54ebbfb3..f409d626935 100644
--- a/cpp/src/join/conditional_join.cu
+++ b/cpp/src/join/conditional_join.cu
@@ -50,9 +50,8 @@ conditional_join(table_view const& left,
   // null index for the right table; in others, we return an empty output.
   if (right.num_rows() == 0) {
     switch (join_type) {
-      // Left, left anti, and full (which are effectively left because we are
-      // guaranteed that left has more rows than right) all return a all the
-      // row indices from left with a corresponding NULL from the right.
+      // Left, left anti, and full all return all the row indices from left
+      // with a corresponding NULL from the right.
       case join_kind::LEFT_JOIN:
       case join_kind::LEFT_ANTI_JOIN:
       case join_kind::FULL_JOIN: return get_trivial_left_join_indices(left, stream);
@@ -61,6 +60,23 @@ conditional_join(table_view const& left,
       case join_kind::LEFT_SEMI_JOIN:
         return std::make_pair(std::make_unique<rmm::device_uvector<size_type>>(0, stream, mr),
                               std::make_unique<rmm::device_uvector<size_type>>(0, stream, mr));
+      default: CUDF_FAIL("Invalid join kind."); break;
+    }
+  } else if (left.num_rows() == 0) {
+    switch (join_type) {
+      // Left, left anti, left semi, and inner joins all return empty sets.
+      case join_kind::LEFT_JOIN:
+      case join_kind::LEFT_ANTI_JOIN:
+      case join_kind::INNER_JOIN:
+      case join_kind::LEFT_SEMI_JOIN:
+        return std::make_pair(std::make_unique<rmm::device_uvector<size_type>>(0, stream, mr),
+                              std::make_unique<rmm::device_uvector<size_type>>(0, stream, mr));
+      // Full joins need to return the trivial complement.
+      case join_kind::FULL_JOIN: {
+        auto ret_flipped = get_trivial_left_join_indices(right, stream);
+        return std::make_pair(std::move(ret_flipped.second), std::move(ret_flipped.first));
+      }
+      default: CUDF_FAIL("Invalid join kind."); break;
     }
   }
 
@@ -113,7 +129,12 @@ conditional_join(table_view const& left,
     join_size = size.value(stream);
   }
 
-  // If the output size will be zero, we can return immediately.
+  // The initial early exit clauses guarantee that we will not reach this point
+  // unless both the left and right tables are non-empty. Under that
+  // constraint, neither left nor full joins can return an empty result since
+  // at minimum we are guaranteed null matches for all non-matching rows. In
+  // all other cases (inner, left semi, and left anti joins) if we reach this
+  // point we can safely return an empty result.
   if (join_size == 0) {
     return std::make_pair(std::make_unique<rmm::device_uvector<size_type>>(0, stream, mr),
                           std::make_unique<rmm::device_uvector<size_type>>(0, stream, mr));
@@ -174,20 +195,31 @@ std::size_t compute_conditional_join_output_size(table_view const& left,
                                                  rmm::cuda_stream_view stream,
                                                  rmm::mr::device_memory_resource* mr)
 {
-  // We can immediately filter out cases where the right table is empty. In
-  // some cases, we return all the rows of the left table with a corresponding
-  // null index for the right table; in others, we return an empty output.
+  // We can immediately filter out cases where one table is empty. In
+  // some cases, we return all the rows of the other table with a corresponding
+  // null index for the empty table; in others, we return an empty output.
   if (right.num_rows() == 0) {
     switch (join_type) {
-      // Left, left anti, and full (which are effectively left because we are
-      // guaranteed that left has more rows than right) all return a all the
-      // row indices from left with a corresponding NULL from the right.
+      // Left, left anti, and full all return all the row indices from left
+      // with a corresponding NULL from the right.
       case join_kind::LEFT_JOIN:
       case join_kind::LEFT_ANTI_JOIN:
       case join_kind::FULL_JOIN: return left.num_rows();
       // Inner and left semi joins return empty output because no matches can exist.
       case join_kind::INNER_JOIN:
       case join_kind::LEFT_SEMI_JOIN: return 0;
+      default: CUDF_FAIL("Invalid join kind."); break;
+    }
+  } else if (left.num_rows() == 0) {
+    switch (join_type) {
+      // Left, left anti, left semi, and inner joins all return empty sets.
+      case join_kind::LEFT_JOIN:
+      case join_kind::LEFT_ANTI_JOIN:
+      case join_kind::INNER_JOIN:
+      case join_kind::LEFT_SEMI_JOIN: return 0;
+      // Full joins need to return the trivial complement.
+      case join_kind::FULL_JOIN: return right.num_rows();
+      default: CUDF_FAIL("Invalid join kind."); break;
     }
   }
 
diff --git a/cpp/tests/join/conditional_join_tests.cu b/cpp/tests/join/conditional_join_tests.cu
index d566d2086bb..6c73fb67d7e 100644
--- a/cpp/tests/join/conditional_join_tests.cu
+++ b/cpp/tests/join/conditional_join_tests.cu
@@ -304,6 +304,11 @@ TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnOneRowAllEqual)
   this->test({{0}}, {{0}}, left_zero_eq_right_zero, {{0, 0}});
 };
 
+TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnLeftEmpty)
+{
+  this->test({{}}, {{3, 4, 5}}, left_zero_eq_right_zero, {});
+};
+
 TYPED_TEST(ConditionalInnerJoinTest, TestOneColumnTwoRowAllEqual)
 {
   this->test({{0, 1}}, {{0, 0}}, left_zero_eq_right_zero, {{0, 0}, {0, 1}});
@@ -489,6 +494,11 @@ TYPED_TEST(ConditionalLeftJoinTest, TestTwoColumnThreeRowSomeEqual)
              {{0, 0}, {1, 1}, {2, JoinNoneValue}});
 };
 
+TYPED_TEST(ConditionalLeftJoinTest, TestOneColumnLeftEmpty)
+{
+  this->test({{}}, {{3, 4, 5}}, left_zero_eq_right_zero, {});
+};
+
 TYPED_TEST(ConditionalLeftJoinTest, TestCompareRandomToHash)
 {
   // Generate columns of 10 repeats of the integer range [0, 10), then merge
@@ -560,6 +570,14 @@ TYPED_TEST(ConditionalFullJoinTest, TestOneColumnNoneEqual)
               {JoinNoneValue, 2}});
 };
 
+TYPED_TEST(ConditionalFullJoinTest, TestOneColumnLeftEmpty)
+{
+  this->test({{}},
+             {{3, 4, 5}},
+             left_zero_eq_right_zero,
+             {{JoinNoneValue, 0}, {JoinNoneValue, 1}, {JoinNoneValue, 2}});
+};
+
 TYPED_TEST(ConditionalFullJoinTest, TestTwoColumnThreeRowSomeEqual)
 {
   this->test({{0, 1, 2}, {10, 20, 30}},