Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose and then implement support for cross joins in cudf-polars #16097

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/join.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,5 @@ cpdef Column left_anti_join(
Table right_keys,
null_equality nulls_equal
)

cpdef Table cross_join(Table left, Table right)
30 changes: 25 additions & 5 deletions python/cudf/cudf/_lib/pylibcudf/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from cython.operator import dereference

from libcpp.memory cimport make_unique
from libcpp.memory cimport make_unique, unique_ptr
from libcpp.utility cimport move

from rmm._lib.device_buffer cimport device_buffer

from cudf._lib.pylibcudf.libcudf cimport join as cpp_join
from cudf._lib.pylibcudf.libcudf.column.column cimport column
from cudf._lib.pylibcudf.libcudf.table.table cimport table
from cudf._lib.pylibcudf.libcudf.types cimport (
data_type,
null_equality,
Expand Down Expand Up @@ -88,7 +89,6 @@ cpdef tuple left_join(
nulls_equal : NullEquality
Should nulls compare equal?


Returns
-------
Tuple[Column, Column]
Expand Down Expand Up @@ -122,7 +122,6 @@ cpdef tuple full_join(
nulls_equal : NullEquality
Should nulls compare equal?


Returns
-------
Tuple[Column, Column]
Expand Down Expand Up @@ -156,7 +155,6 @@ cpdef Column left_semi_join(
nulls_equal : NullEquality
Should nulls compare equal?


Returns
-------
Column
Expand Down Expand Up @@ -190,7 +188,6 @@ cpdef Column left_anti_join(
nulls_equal : NullEquality
Should nulls compare equal?


Returns
-------
Column
Expand All @@ -204,3 +201,26 @@ cpdef Column left_anti_join(
nulls_equal
)
return _column_from_gather_map(move(c_result))


cpdef Table cross_join(Table left, Table right):
"""Perform a cross join on two tables.

For details see :cpp:func:`cross_join`.

Parameters
----------
left : Table
The left table to join.
right: Table
The right table to join.

Returns
-------
Table
The result of cross joining the two inputs.
"""
cdef unique_ptr[table] result
with nogil:
result = move(cpp_join.cross_join(left.view(), right.view()))
return Table.from_libcudf(move(result))
5 changes: 5 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/libcudf/join.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,8 @@ cdef extern from "cudf/join.hpp" namespace "cudf" nogil:
const table_view right_keys,
null_equality nulls_equal,
) except +

cdef unique_ptr[table] cross_join(
const table_view left,
const table_view right,
) except +
29 changes: 29 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/test_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import numpy as np
import pyarrow as pa
from utils import assert_table_eq

from cudf._lib import pylibcudf as plc


def test_cross_join():
left = pa.Table.from_arrays([[0, 1, 2], [3, 4, 5]], names=["a", "b"])
right = pa.Table.from_arrays(
[[6, 7, 8, 9], [10, 11, 12, 13]], names=["c", "d"]
)

pleft = plc.interop.from_arrow(left)
pright = plc.interop.from_arrow(right)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might consider fixturizing at least the GPU objects here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that there's only one value, and one test, I am inclined not to, unless there is a really compelling reason.


expect = pa.Table.from_arrays(
[
*(np.repeat(c.to_numpy(), len(right)) for c in left.columns),
*(np.tile(c.to_numpy(), len(left)) for c in right.columns),
],
names=["a", "b", "c", "d"],
)

got = plc.join.cross_join(pleft, pright)

assert_table_eq(expect, got)
29 changes: 21 additions & 8 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ class Join(IR):
right_on: list[expr.NamedExpr]
"""List of expressions used as keys in the right frame."""
options: tuple[
Literal["inner", "left", "full", "leftsemi", "leftanti"],
Literal["inner", "left", "full", "leftsemi", "leftanti", "cross"],
bool,
tuple[int, int] | None,
str | None,
Expand All @@ -518,11 +518,6 @@ class Join(IR):
- coalesce: should key columns be coalesced (only makes sense for outer joins)
"""

def __post_init__(self) -> None:
"""Validate preconditions."""
if self.options[0] == "cross":
raise NotImplementedError("cross join not implemented")

@cache
@staticmethod
def _joiners(
Expand Down Expand Up @@ -567,6 +562,26 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
left = self.left.evaluate(cache=cache)
right = self.right.evaluate(cache=cache)
how, join_nulls, zlice, suffix, coalesce = self.options
suffix = "_right" if suffix is None else suffix
if how == "cross":
# Separate implementation, since cross_join returns the
# result, not the gather maps
columns = plc.join.cross_join(left.table, right.table).columns()
left_cols = [
NamedColumn(new, old.name).sorted_like(old)
for new, old in zip(columns[: left.num_columns], left.columns)
]
right_cols = [
NamedColumn(
new,
old.name
if old.name not in left.column_names_set
else f"{old.name}{suffix}",
)
for new, old in zip(columns[left.num_columns :], right.columns)
]
return DataFrame([*left_cols, *right_cols])
left_on = DataFrame(
broadcast(
*(e.evaluate(left) for e in self.left_on), target_length=left.num_rows
Expand All @@ -578,13 +593,11 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
target_length=right.num_rows,
)
)
how, join_nulls, zlice, suffix, coalesce = self.options
null_equality = (
plc.types.NullEquality.EQUAL
if join_nulls
else plc.types.NullEquality.UNEQUAL
)
suffix = "_right" if suffix is None else suffix
join_fn, left_policy, right_policy = Join._joiners(how)
if right_policy is None:
# Semi join
Expand Down
24 changes: 20 additions & 4 deletions python/cudf_polars/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
"left",
"semi",
"anti",
pytest.param(
"cross",
marks=pytest.mark.xfail(reason="cross join not implemented"),
),
wence- marked this conversation as resolved.
Show resolved Hide resolved
"full",
],
)
Expand Down Expand Up @@ -55,3 +51,23 @@ def test_join(how, coalesce, join_nulls, join_expr):
right, on=join_expr, how=how, join_nulls=join_nulls, coalesce=coalesce
)
assert_gpu_result_equal(query, check_row_order=False)


def test_cross_join():
left = pl.DataFrame(
{
"a": [1, 2, 3, 1, None],
"b": [1, 2, 3, 4, 5],
"c": [2, 3, 4, 5, 6],
}
).lazy()
right = pl.DataFrame(
{
"a": [1, 4, 3, 7, None, None],
"c": [2, 3, 4, 5, 6, 7],
}
).lazy()

q = left.join(right, how="cross")

assert_gpu_result_equal(q)
Loading