Merge pull request #16347 from wence-/wence/fea/polars-engine-config
Use new polars engine config object in cudf-polars callback
wence- authored Aug 2, 2024
2 parents 445a75f + abcf22b commit 62a5dbd
Showing 17 changed files with 334 additions and 116 deletions.
2 changes: 1 addition & 1 deletion dependencies.yaml
@@ -631,7 +631,7 @@ dependencies:
common:
- output_types: [conda, requirements, pyproject]
packages:
- polars>=1.0
- polars>=1.3
run_dask_cudf:
common:
- output_types: [conda, requirements, pyproject]
30 changes: 27 additions & 3 deletions python/cudf_polars/cudf_polars/__init__.py
@@ -10,9 +10,33 @@

from __future__ import annotations

from cudf_polars._version import __git_commit__, __version__
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir
import os
import warnings

# We want to avoid initialising the GPU on import. Unfortunately,
# while we still depend on cudf, cudf runs GPU validation at import
# time by default. Setting RAPIDS_NO_INITIALIZE makes cudf skip that
# import-time validation.
# We additionally must set the ptxcompiler environment variable, so
# that we don't check if a numba patch is needed. But if this is done,
# then the patching mechanism warns, and we want to squash that
# warning too.
# TODO: Remove this when we only depend on a pylibcudf package.
os.environ["RAPIDS_NO_INITIALIZE"] = "1"
os.environ["PTXCOMPILER_CHECK_NUMBA_CODEGEN_PATCH_NEEDED"] = "0"
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import cudf

del cudf

# Check we have a supported polars version
import cudf_polars.utils.versions as v # noqa: E402
from cudf_polars._version import __git_commit__, __version__ # noqa: E402
from cudf_polars.callback import execute_with_cudf # noqa: E402
from cudf_polars.dsl.translate import translate_ir # noqa: E402

del v

__all__: list[str] = [
"execute_with_cudf",
139 changes: 131 additions & 8 deletions python/cudf_polars/cudf_polars/callback.py
@@ -5,43 +5,153 @@

from __future__ import annotations

import contextlib
import os
import warnings
from functools import partial
from functools import cache, partial
from typing import TYPE_CHECKING

import nvtx

from polars.exceptions import PerformanceWarning
from polars.exceptions import ComputeError, PerformanceWarning

import rmm
from rmm._cuda import gpu

from cudf_polars.dsl.translate import translate_ir

if TYPE_CHECKING:
from collections.abc import Generator

import polars as pl
from polars import GPUEngine

from cudf_polars.dsl.ir import IR
from cudf_polars.typing import NodeTraverser

__all__: list[str] = ["execute_with_cudf"]


@cache
def default_memory_resource(device: int) -> rmm.mr.DeviceMemoryResource:
"""
Return the default memory resource for cudf-polars.
Parameters
----------
device
Disambiguating device id when selecting the device. Must be
the active device when this function is called.
Returns
-------
rmm.mr.DeviceMemoryResource
The default memory resource that cudf-polars uses. Currently
an async pool resource.
"""
try:
return rmm.mr.CudaAsyncMemoryResource()
except RuntimeError as e: # pragma: no cover
msg, *_ = e.args
if (
msg.startswith("RMM failure")
and msg.find("not supported with this CUDA driver/runtime version") > -1
):
raise ComputeError(
"GPU engine requested, but incorrect cudf-polars package installed. "
"If your system has a CUDA 11 driver, please uninstall `cudf-polars-cu12` "
"and install `cudf-polars-cu11`"
) from None
else:
raise
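
As a usage note on the function above: the async resource is only the fallback used when the engine config carries no explicit memory resource. A minimal sketch, assuming an RMM installation, a working CUDA driver, and that the installed polars exposes the `GPUEngine` object and the `engine=` argument to `collect` (the pool choice is illustrative, not what cudf-polars ships by default):

    import polars as pl
    import rmm

    # Hypothetical explicit pool: a 1 GiB pool layered on the plain CUDA resource.
    mr = rmm.mr.PoolMemoryResource(
        rmm.mr.CudaMemoryResource(), initial_pool_size=2**30
    )

    q = pl.LazyFrame({"a": [1, 2, 3]}).select(pl.col("a") * 2)
    # When memory_resource is omitted, default_memory_resource() above supplies
    # the cached CudaAsyncMemoryResource for the active device instead.
    result = q.collect(engine=pl.GPUEngine(memory_resource=mr))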


@contextlib.contextmanager
def set_memory_resource(
mr: rmm.mr.DeviceMemoryResource | None,
) -> Generator[rmm.mr.DeviceMemoryResource, None, None]:
"""
Set the current memory resource for an execution block.
Parameters
----------
mr
Memory resource to use. If `None`, calls :func:`default_memory_resource`
to obtain an mr on the currently active device.
Returns
-------
Memory resource used.
Notes
-----
At exit, the memory resource is restored to whatever was current
at entry. If a memory resource is provided, it must be valid to
use with the currently active device.
"""
if mr is None:
device: int = gpu.getDevice()
mr = default_memory_resource(device)
previous = rmm.mr.get_current_device_resource()
rmm.mr.set_current_device_resource(mr)
try:
yield mr
finally:
rmm.mr.set_current_device_resource(previous)


@contextlib.contextmanager
def set_device(device: int | None) -> Generator[int, None, None]:
"""
Set the device the query is executed on.
Parameters
----------
device
Device to use. If `None`, uses the current device.
Returns
-------
Device active for the execution of the block.
Notes
-----
At exit, the device is restored to whatever was current at entry.
"""
previous: int = gpu.getDevice()
if device is not None:
gpu.setDevice(device)
try:
yield previous
finally:
gpu.setDevice(previous)
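
A small composition sketch (an assumption about typical use, mirroring `_callback` below): the device is entered first so that, when `mr` is `None`, `default_memory_resource` is created and cached against the device the query will actually run on.

    import rmm
    from rmm._cuda import gpu

    from cudf_polars.callback import set_device, set_memory_resource

    # Hypothetical: run a block on device 0 with the default (cached) resource.
    with set_device(0), set_memory_resource(None) as mr:
        # Inside the block: device 0 is active and `mr` is current for RMM.
        assert gpu.getDevice() == 0
        assert rmm.mr.get_current_device_resource() is mr
    # On exit both the previous device and the previous resource are restored.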


def _callback(
ir: IR,
with_columns: list[str] | None,
pyarrow_predicate: str | None,
n_rows: int | None,
*,
device: int | None,
memory_resource: int | None,
) -> pl.DataFrame:
assert with_columns is None
assert pyarrow_predicate is None
assert n_rows is None
with nvtx.annotate(message="ExecuteIR", domain="cudf_polars"):
with (
nvtx.annotate(message="ExecuteIR", domain="cudf_polars"),
# Device must be set before memory resource is obtained.
set_device(device),
set_memory_resource(memory_resource),
):
return ir.evaluate(cache={}).to_polars()


def execute_with_cudf(
nt: NodeTraverser,
*,
raise_on_fail: bool = False,
config: GPUEngine,
exception: type[Exception] | tuple[type[Exception], ...] = Exception,
) -> None:
"""
@@ -52,19 +162,32 @@ def execute_with_cudf(
nt
NodeTraverser
raise_on_fail
Should conversion raise an exception rather than continuing
without setting a callback.
config
GPUEngine configuration object
exception
Optional exception, or tuple of exceptions, to catch during
translation. Defaults to ``Exception``.
The NodeTraverser is mutated if the libcudf executor can handle the plan.
"""
device = config.device
memory_resource = config.memory_resource
raise_on_fail = config.config.get("raise_on_fail", False)
if unsupported := (config.config.keys() - {"raise_on_fail"}):
raise ValueError(
f"Engine configuration contains unsupported settings {unsupported}"
)
try:
with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
nt.set_udf(partial(_callback, translate_ir(nt)))
nt.set_udf(
partial(
_callback,
translate_ir(nt),
device=device,
memory_resource=memory_resource,
)
)
except exception as e:
if bool(int(os.environ.get("POLARS_VERBOSE", 0))):
warnings.warn(
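
To make the new configuration path concrete, a hedged end-to-end sketch (assuming a CUDA-capable machine and that the installed polars exposes `GPUEngine` and the `engine=` argument to `collect`): `device` and `memory_resource` are forwarded to `_callback`, while any remaining keyword, currently only `raise_on_fail`, arrives via `config.config`; anything else triggers the `ValueError` above.

    import polars as pl

    q = (
        pl.LazyFrame({"key": [1, 1, 2], "val": [1.0, 2.0, 3.0]})
        .group_by("key")
        .agg(pl.col("val").sum())
    )

    # Default: silently fall back to the CPU engine if translation fails.
    lenient = pl.GPUEngine()
    # Stricter: pin execution to device 0 and raise instead of falling back.
    strict = pl.GPUEngine(device=0, raise_on_fail=True)

    result = q.collect(engine=strict)
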
8 changes: 4 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/expr.py
@@ -885,9 +885,9 @@ def __init__(
if self.name not in (
"mask_nans",
"round",
"setsorted",
"set_sorted",
"unique",
"dropnull",
"drop_nulls",
"fill_null",
):
raise NotImplementedError(f"Unary function {name=}")
@@ -948,7 +948,7 @@ def do_evaluate(
if maintain_order:
return Column(column).sorted_like(values)
return Column(column)
elif self.name == "setsorted":
elif self.name == "set_sorted":
(column,) = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
@@ -975,7 +975,7 @@
order=order,
null_order=null_order,
)
elif self.name == "dropnull":
elif self.name == "drop_nulls":
(column,) = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
18 changes: 12 additions & 6 deletions python/cudf_polars/cudf_polars/dsl/ir.py
@@ -625,7 +625,7 @@ class Join(IR):
right_on: list[expr.NamedExpr]
"""List of expressions used as keys in the right frame."""
options: tuple[
Literal["inner", "left", "full", "leftsemi", "leftanti", "cross"],
Literal["inner", "left", "right", "full", "leftsemi", "leftanti", "cross"],
bool,
tuple[int, int] | None,
str | None,
@@ -651,7 +651,7 @@ def __post_init__(self) -> None:
@staticmethod
@cache
def _joiners(
how: Literal["inner", "left", "full", "leftsemi", "leftanti"],
how: Literal["inner", "left", "right", "full", "leftsemi", "leftanti"],
) -> tuple[
Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
]:
@@ -661,7 +661,7 @@ def _joiners(
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
)
elif how == "left":
elif how == "left" or how == "right":
return (
plc.join.left_join,
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
@@ -685,8 +685,7 @@ def _joiners(
plc.copying.OutOfBoundsPolicy.DONT_CHECK,
None,
)
else:
assert_never(how)
assert_never(how)

def _reorder_maps(
self,
@@ -780,8 +779,12 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
table = plc.copying.gather(left.table, lg, left_policy)
result = DataFrame.from_table(table, left.column_names)
else:
if how == "right":
# Right join is a left join with the tables swapped
left, right = right, left
left_on, right_on = right_on, left_on
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
if how == "left":
if how == "left" or how == "right":
# Order of left table is preserved
lg, rg = self._reorder_maps(
left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
Expand All @@ -808,6 +811,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
)
)
right = right.discard_columns(right_on.column_names_set)
if how == "right":
# Undo the swap for right join before gluing together.
left, right = right, left
right = right.rename_columns(
{
name: f"{name}{suffix}"
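
The swap trick used above, a right join executed as a left join with the inputs exchanged and the row pairs swapped back afterwards, can be illustrated with a small pure-Python sketch (hypothetical helpers over dict rows, not part of the cudf-polars or libcudf API):

    def left_join(left, right, key):
        # Keep every row of `left`; pair it with each matching row of `right`,
        # or with None when there is no match (a left outer join).
        return [
            (lrow, rrow)
            for lrow in left
            for rrow in ([r for r in right if r[key] == lrow[key]] or [None])
        ]


    def right_join(left, right, key):
        # Right join == left join with the inputs swapped ...
        swapped = left_join(right, left, key)
        # ... then swap each pair back so results stay in (left, right) order.
        return [(lrow, rrow) for rrow, lrow in swapped]


    # Tiny usage example with rows as dicts.
    left = [{"id": 1, "x": "a"}, {"id": 2, "x": "b"}]
    right = [{"id": 2, "y": "B"}, {"id": 3, "y": "C"}]
    print(right_join(left, right, "id"))
    # [({'id': 2, 'x': 'b'}, {'id': 2, 'y': 'B'}), (None, {'id': 3, 'y': 'C'})]
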
36 changes: 29 additions & 7 deletions python/cudf_polars/cudf_polars/dsl/translate.py
@@ -76,13 +76,26 @@ def _translate_ir(
def _(
node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
return ir.PythonScan(
schema,
node.options,
translate_named_expr(visitor, n=node.predicate)
if node.predicate is not None
else None,
)
if visitor.version()[0] == 1: # pragma: no cover
# https://github.com/pola-rs/polars/pull/17939
# Versioning can be dropped once polars 1.4 is lowest
# supported version.
scan_fn, with_columns, source_type, predicate, nrows = node.options
options = (scan_fn, with_columns, source_type, nrows)
predicate = (
translate_named_expr(visitor, n=predicate)
if predicate is not None
else None
)
else:
# version == 0
options = node.options
predicate = (
translate_named_expr(visitor, n=node.predicate)
if node.predicate is not None
else None
)
return ir.PythonScan(schema, options, predicate)


@_translate_ir.register
@@ -294,6 +307,15 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
ctx: AbstractContextManager[None] = (
set_node(visitor, n) if n is not None else noop_context
)
# IR is versioned with major.minor, minor is bumped for backwards
# compatible changes (e.g. adding new nodes), major is bumped for
# incompatible changes (e.g. renaming nodes).
# Polars 1.4 changes definition of PythonScan.
if (version := visitor.version()) >= (2, 0):
raise NotImplementedError(
f"No support for polars IR {version=}"
) # pragma: no cover; no such version for now.

with ctx:
node = visitor.view_current_node()
schema = {k: dtypes.from_polars(v) for k, v in visitor.get_schema().items()}
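
A worked illustration of the version guard in `translate_ir` above (the version tuples are hypothetical, not actual polars releases): minor bumps stay accepted because they only add nodes, while a major bump signals renamed or reshaped nodes and is rejected until translation support catches up.

    def ir_version_supported(version: tuple[int, int]) -> bool:
        # Mirror of the check in translate_ir: accept anything below 2.0.
        return version < (2, 0)

    assert ir_version_supported((1, 0))       # current IR major version
    assert ir_version_supported((1, 7))       # backwards-compatible additions
    assert not ir_version_supported((2, 0))   # incompatible changes need new code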