Commit 31533a1

Use new GPUEngine config object to set things up
wence- committed Jul 22, 2024
1 parent 852b151 commit 31533a1
Showing 2 changed files with 116 additions and 13 deletions.

python/cudf_polars/cudf_polars/callback.py: 120 changes (113 additions & 7 deletions)

@@ -5,39 +5,135 @@
 
 from __future__ import annotations
 
-from functools import partial
+import contextlib
+from functools import cache, partial
 from typing import TYPE_CHECKING
 
 import nvtx
 
+import rmm
+from rmm._cuda import gpu
+
 from cudf_polars.dsl.translate import translate_ir
 
 if TYPE_CHECKING:
+    from collections.abc import Generator
+
     import polars as pl
+    from polars import GPUEngine
 
     from cudf_polars.dsl.ir import IR
     from cudf_polars.typing import NodeTraverser
 
 __all__: list[str] = ["execute_with_cudf"]
 
 
+@cache
+def default_memory_resource(device: int) -> rmm.mr.DeviceMemoryResource:
+    """
+    Return the default memory resource for cudf-polars.
+
+    Parameters
+    ----------
+    device
+        Disambiguating device id when selecting the device. Must be
+        the active device when this function is called.
+
+    Returns
+    -------
+    rmm.mr.DeviceMemoryResource
+        The default memory resource that cudf-polars uses. Currently
+        an async pool resource.
+    """
+    return rmm.mr.CudaAsyncMemoryResource()
+
+
+@contextlib.contextmanager
+def set_memory_resource(
+    mr: rmm.mr.DeviceMemoryResource | None,
+) -> Generator[rmm.mr.DeviceMemoryResource, None, None]:
+    """
+    Set the current memory resource for an execution block.
+
+    Parameters
+    ----------
+    mr
+        Memory resource to use. If `None`, calls :func:`default_memory_resource`
+        to obtain an mr on the currently active device.
+
+    Returns
+    -------
+    Memory resource used.
+
+    Notes
+    -----
+    At exit, the memory resource is restored to whatever was current
+    at entry. If a memory resource is provided, it must be valid to
+    use with the currently active device.
+    """
+    if mr is None:
+        device: int = gpu.getDevice()
+        mr = default_memory_resource(device)
+    previous = rmm.mr.get_current_device_resource()
+    rmm.mr.set_current_device_resource(mr)
+    try:
+        yield mr
+    finally:
+        rmm.mr.set_current_device_resource(previous)
+
+
+@contextlib.contextmanager
+def set_device(device: int | None) -> Generator[int, None, None]:
+    """
+    Set the device the query is executed on.
+
+    Parameters
+    ----------
+    device
+        Device to use. If `None`, uses the current device.
+
+    Returns
+    -------
+    Device active for the execution of the block.
+
+    Notes
+    -----
+    At exit, the device is restored to whatever was current at entry.
+    """
+    previous: int = gpu.getDevice()
+    if device is not None:
+        gpu.setDevice(device)
+    try:
+        yield previous
+    finally:
+        gpu.setDevice(previous)
+
+
 def _callback(
     ir: IR,
     with_columns: list[str] | None,
     pyarrow_predicate: str | None,
     n_rows: int | None,
+    *,
+    device: int | None,
+    memory_resource: int | None,
 ) -> pl.DataFrame:
     assert with_columns is None
     assert pyarrow_predicate is None
     assert n_rows is None
-    with nvtx.annotate(message="ExecuteIR", domain="cudf_polars"):
+    with (
+        nvtx.annotate(message="ExecuteIR", domain="cudf_polars"),
+        # Device must be set before memory resource is obtained.
+        set_device(device),
+        set_memory_resource(memory_resource),
+    ):
         return ir.evaluate(cache={}).to_polars()
 
 
 def execute_with_cudf(
     nt: NodeTraverser,
     *,
-    raise_on_fail: bool = False,
+    config: GPUEngine,
     exception: type[Exception] | tuple[type[Exception], ...] = Exception,
 ) -> None:
     """
@@ -48,19 +144,29 @@ def execute_with_cudf(
     nt
         NodeTraverser
 
-    raise_on_fail
-        Should conversion raise an exception rather than continuing
-        without setting a callback.
+    config
+        GPUEngine configuration object
 
     exception
         Optional exception, or tuple of exceptions, to catch during
         translation. Defaults to ``Exception``.
 
     The NodeTraverser is mutated if the libcudf executor can handle the plan.
     """
+    device = config.device
+    memory_resource = config.memory_resource
+    raise_on_fail = config.config.get("raise_on_fail", False)
+
     try:
         with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
-            nt.set_udf(partial(_callback, translate_ir(nt)))
+            nt.set_udf(
+                partial(
+                    _callback,
+                    translate_ir(nt),
+                    device=device,
+                    memory_resource=memory_resource,
+                )
+            )
     except exception:
         if raise_on_fail:
             raise
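
For orientation, here is a minimal sketch of how the new configuration object reaches this callback from the user side. It assumes a GPU-enabled polars build with cudf-polars installed; the query, the device id, and the explicit pool resource are illustrative and not part of this commit.

import polars as pl
import rmm

# An explicit memory resource; if none is supplied, default_memory_resource()
# above falls back to rmm.mr.CudaAsyncMemoryResource() on the active device.
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())

q = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).select(pl.col("a") + pl.col("b"))

# Keyword arguments other than device/memory_resource land in config.config,
# which is where execute_with_cudf reads raise_on_fail from.
result = q.collect(
    engine=pl.GPUEngine(device=0, memory_resource=mr, raise_on_fail=True)
)
print(result)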

python/cudf_polars/cudf_polars/testing/asserts.py: 9 changes (3 additions & 6 deletions)

@@ -5,12 +5,11 @@
 
 from __future__ import annotations
 
-from functools import partial
 from typing import TYPE_CHECKING
 
+from polars import GPUEngine
 from polars.testing.asserts import assert_frame_equal
 
-from cudf_polars.callback import execute_with_cudf
 from cudf_polars.dsl.translate import translate_ir
 
 if TYPE_CHECKING:
@@ -70,10 +69,8 @@ def assert_gpu_result_equal(
     """
     collect_kwargs = {} if collect_kwargs is None else collect_kwargs
     expect = lazydf.collect(**collect_kwargs)
-    got = lazydf.collect(
-        **collect_kwargs,
-        post_opt_callback=partial(execute_with_cudf, raise_on_fail=True),
-    )
+    engine = GPUEngine(raise_on_fail=True)
+    got = lazydf.collect(**collect_kwargs, engine=engine)
     assert_frame_equal(
         expect,
         got,
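
For completeness, a hedged sketch of how a test exercises the GPU engine through this helper after the change; the frame and aggregation are made up for illustration.

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal

# Collects the query once on the CPU and once with GPUEngine(raise_on_fail=True),
# then compares the two results via polars' assert_frame_equal.
q = pl.LazyFrame({"a": [1, 2, 3]}).select(pl.col("a").sum())
assert_gpu_result_equal(q)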
