-
Notifications
You must be signed in to change notification settings - Fork 917
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use new polars engine config object in cudf-polars callback #16347
Changes from 8 commits
7742b8b
ef0b49f
e9fd96d
9d69621
918a40e
f8f2d0d
bcedb6b
6f2d406
f3bbd3f
1d4c30c
abcf22b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,9 +10,33 @@ | |
|
||
from __future__ import annotations | ||
|
||
from cudf_polars._version import __git_commit__, __version__ | ||
from cudf_polars.callback import execute_with_cudf | ||
from cudf_polars.dsl.translate import translate_ir | ||
import os | ||
import warnings | ||
|
||
# We want to avoid initialising the GPU on import. Unfortunately, | ||
# while we still depend on cudf, the default mode is to check things. | ||
# If we set RAPIDS_NO_INITIALIZE, then cudf doesn't do import-time | ||
# validation, good. | ||
# We additionally must set the ptxcompiler environment variable, so | ||
# that we don't check if a numba patch is needed. But if this is done, | ||
# then the patching mechanism warns, and we want to squash that | ||
# warning too. | ||
# TODO: Remove this when we only depend on a pylibcudf package. | ||
os.environ["RAPIDS_NO_INITIALIZE"] = "1" | ||
os.environ["PTXCOMPILER_CHECK_NUMBA_CODEGEN_PATCH_NEEDED"] = "0" | ||
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore") | ||
import cudf | ||
|
||
del cudf | ||
Comment on lines
+16
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This avoids ugly error messages when importing on a too-old driver. |
||
|
||
# Check we have a supported polars version | ||
import cudf_polars.utils.versions as v # noqa: E402 | ||
from cudf_polars._version import __git_commit__, __version__ # noqa: E402 | ||
from cudf_polars.callback import execute_with_cudf # noqa: E402 | ||
from cudf_polars.dsl.translate import translate_ir # noqa: E402 | ||
|
||
del v | ||
|
||
__all__: list[str] = [ | ||
"execute_with_cudf", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -625,7 +625,7 @@ class Join(IR): | |
right_on: list[expr.NamedExpr] | ||
"""List of expressions used as keys in the right frame.""" | ||
options: tuple[ | ||
Literal["inner", "left", "full", "leftsemi", "leftanti", "cross"], | ||
Literal["inner", "left", "right", "full", "leftsemi", "leftanti", "cross"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added support for the new right join style, since it was not much more effort than raising as an unsupported type. |
||
bool, | ||
tuple[int, int] | None, | ||
str | None, | ||
|
@@ -651,7 +651,7 @@ def __post_init__(self) -> None: | |
@staticmethod | ||
@cache | ||
def _joiners( | ||
how: Literal["inner", "left", "full", "leftsemi", "leftanti"], | ||
how: Literal["inner", "left", "right", "full", "leftsemi", "leftanti"], | ||
) -> tuple[ | ||
Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None | ||
]: | ||
|
@@ -661,7 +661,7 @@ def _joiners( | |
plc.copying.OutOfBoundsPolicy.DONT_CHECK, | ||
plc.copying.OutOfBoundsPolicy.DONT_CHECK, | ||
) | ||
elif how == "left": | ||
elif how == "left" or how == "right": | ||
return ( | ||
plc.join.left_join, | ||
plc.copying.OutOfBoundsPolicy.DONT_CHECK, | ||
|
@@ -685,8 +685,7 @@ def _joiners( | |
plc.copying.OutOfBoundsPolicy.DONT_CHECK, | ||
None, | ||
) | ||
else: | ||
assert_never(how) | ||
assert_never(how) | ||
|
||
def _reorder_maps( | ||
self, | ||
|
@@ -780,8 +779,12 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: | |
table = plc.copying.gather(left.table, lg, left_policy) | ||
result = DataFrame.from_table(table, left.column_names) | ||
else: | ||
if how == "right": | ||
# Right join is a left join with the tables swapped | ||
left, right = right, left | ||
left_on, right_on = right_on, left_on | ||
lg, rg = join_fn(left_on.table, right_on.table, null_equality) | ||
if how == "left": | ||
if how == "left" or how == "right": | ||
# Order of left table is preserved | ||
lg, rg = self._reorder_maps( | ||
left.num_rows, lg, left_policy, right.num_rows, rg, right_policy | ||
|
@@ -808,6 +811,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: | |
) | ||
) | ||
right = right.discard_columns(right_on.column_names_set) | ||
if how == "right": | ||
# Undo the swap for right join before gluing together. | ||
left, right = right, left | ||
right = right.rename_columns( | ||
{ | ||
name: f"{name}{suffix}" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Polars 1.3 is the version that has the public collect UX, hence defines the minimum version we support.