From f05e89db8f6750232a452d072fa9f9ea988a6b34 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Mon, 25 Nov 2024 13:03:54 -0600 Subject: [PATCH] Single-partition Dask executor for cuDF-Polars (#17262) The goal here is to lay down the initial foundation for dask-based evaluation of `IR` graphs in cudf-polars. The first pass will only support single-partition workloads. This functionality could be achieved with much less-complicated changes to cudf-polars. However, we **do** want to build multi-partition support on top of this. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) - James Lamb (https://github.com/jameslamb) URL: https://github.com/rapidsai/cudf/pull/17262 --- ci/run_cudf_polars_pytests.sh | 4 + python/cudf_polars/cudf_polars/callback.py | 18 +- python/cudf_polars/cudf_polars/dsl/ir.py | 25 +- .../cudf_polars/cudf_polars/dsl/translate.py | 3 +- .../cudf_polars/experimental/parallel.py | 236 ++++++++++++++++++ .../cudf_polars/testing/asserts.py | 11 +- python/cudf_polars/tests/conftest.py | 16 ++ .../tests/experimental/test_parallel.py | 21 ++ python/cudf_polars/tests/test_executors.py | 68 +++++ 9 files changed, 388 insertions(+), 14 deletions(-) create mode 100644 python/cudf_polars/cudf_polars/experimental/parallel.py create mode 100644 python/cudf_polars/tests/experimental/test_parallel.py create mode 100644 python/cudf_polars/tests/test_executors.py diff --git a/ci/run_cudf_polars_pytests.sh b/ci/run_cudf_polars_pytests.sh index c10612a065a..bf5a3ccee8e 100755 --- a/ci/run_cudf_polars_pytests.sh +++ b/ci/run_cudf_polars_pytests.sh @@ -8,4 +8,8 @@ set -euo pipefail # Support invoking run_cudf_polars_pytests.sh outside the script directory cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cudf_polars/ +# Test the default "cudf" executor python -m pytest --cache-clear "$@" tests + +# Test the "dask-experimental" executor +python -m pytest --cache-clear "$@" tests --executor dask-experimental diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index 8dc5715195d..95527028aa9 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -9,7 +9,7 @@ import os import warnings from functools import cache, partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import nvtx @@ -181,6 +181,7 @@ def _callback( *, device: int | None, memory_resource: int | None, + executor: Literal["pylibcudf", "dask-experimental"] | None, ) -> pl.DataFrame: assert with_columns is None assert pyarrow_predicate is None @@ -191,7 +192,14 @@ def _callback( set_device(device), set_memory_resource(memory_resource), ): - return ir.evaluate(cache={}).to_polars() + if executor is None or executor == "pylibcudf": + return ir.evaluate(cache={}).to_polars() + elif executor == "dask-experimental": + from cudf_polars.experimental.parallel import evaluate_dask + + return evaluate_dask(ir).to_polars() + else: + raise ValueError(f"Unknown executor '{executor}'") def validate_config_options(config: dict) -> None: @@ -208,7 +216,9 @@ def validate_config_options(config: dict) -> None: ValueError If the configuration contains unsupported options. """ - if unsupported := (config.keys() - {"raise_on_fail", "parquet_options"}): + if unsupported := ( + config.keys() - {"raise_on_fail", "parquet_options", "executor"} + ): raise ValueError( f"Engine configuration contains unsupported settings: {unsupported}" ) @@ -243,6 +253,7 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None: device = config.device memory_resource = config.memory_resource raise_on_fail = config.config.get("raise_on_fail", False) + executor = config.config.get("executor", None) validate_config_options(config.config) with nvtx.annotate(message="ConvertIR", domain="cudf_polars"): @@ -272,5 +283,6 @@ def execute_with_cudf(nt: NodeTraverser, *, config: GPUEngine) -> None: ir, device=device, memory_resource=memory_resource, + executor=executor, ) ) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 62a2da9dcea..6617b71be81 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -1599,13 +1599,15 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR): # polars requires that all to-explode columns have the # same sub-shapes raise NotImplementedError("Explode with more than one column") + self.options = (tuple(to_explode),) elif self.name == "rename": - old, new, _ = self.options + old, new, strict = self.options # TODO: perhaps polars should validate renaming in the IR? if len(new) != len(set(new)) or ( set(new) & (set(df.schema.keys()) - set(old)) ): raise NotImplementedError("Duplicate new names in rename.") + self.options = (tuple(old), tuple(new), strict) elif self.name == "unpivot": indices, pivotees, variable_name, value_name = self.options value_name = "value" if value_name is None else value_name @@ -1623,13 +1625,15 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR): self.options = ( tuple(indices), tuple(pivotees), - (variable_name, schema[variable_name]), - (value_name, schema[value_name]), + variable_name, + value_name, ) - self._non_child_args = (name, self.options) + self._non_child_args = (schema, name, self.options) @classmethod - def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame: + def do_evaluate( + cls, schema: Schema, name: str, options: Any, df: DataFrame + ) -> DataFrame: """Evaluate and return a dataframe.""" if name == "rechunk": # No-op in our data model @@ -1651,8 +1655,8 @@ def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame: ( indices, pivotees, - (variable_name, variable_dtype), - (value_name, value_dtype), + variable_name, + value_name, ) = options npiv = len(pivotees) index_columns = [ @@ -1669,7 +1673,7 @@ def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame: plc.interop.from_arrow( pa.array( pivotees, - type=plc.interop.to_arrow(variable_dtype), + type=plc.interop.to_arrow(schema[variable_name]), ), ) ] @@ -1677,7 +1681,10 @@ def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame: df.num_rows, ).columns() value_column = plc.concatenate.concatenate( - [df.column_map[pivotee].astype(value_dtype).obj for pivotee in pivotees] + [ + df.column_map[pivotee].astype(schema[value_name]).obj + for pivotee in pivotees + ] ) return DataFrame( [ diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 12fc2a196cd..9480ce6e535 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -633,9 +633,10 @@ def _(node: pl_expr.Sort, translator: Translator, dtype: plc.DataType) -> expr.E @_translate_expr.register def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr.Expr: + options = node.sort_options return expr.SortBy( dtype, - node.sort_options, + (options[0], tuple(options[1]), tuple(options[2])), translator.translate_expr(n=node.expr), *(translator.translate_expr(n=n) for n in node.by), ) diff --git a/python/cudf_polars/cudf_polars/experimental/parallel.py b/python/cudf_polars/cudf_polars/experimental/parallel.py new file mode 100644 index 00000000000..6518dd60c7d --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/parallel.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Partitioned LogicalPlan nodes.""" + +from __future__ import annotations + +import operator +from functools import reduce, singledispatch +from typing import TYPE_CHECKING, Any + +from cudf_polars.dsl.ir import IR +from cudf_polars.dsl.traversal import traversal + +if TYPE_CHECKING: + from collections.abc import MutableMapping + from typing import TypeAlias + + from cudf_polars.containers import DataFrame + from cudf_polars.dsl.nodebase import Node + from cudf_polars.typing import GenericTransformer + + +class PartitionInfo: + """ + Partitioning information. + + This class only tracks the partition count (for now). + """ + + __slots__ = ("count",) + + def __init__(self, count: int): + self.count = count + + +LowerIRTransformer: TypeAlias = ( + "GenericTransformer[IR, MutableMapping[IR, PartitionInfo]]" +) +"""Protocol for Lowering IR nodes.""" + + +def get_key_name(node: Node) -> str: + """Generate the key name for a Node.""" + return f"{type(node).__name__.lower()}-{hash(node)}" + + +@singledispatch +def lower_ir_node( + ir: IR, rec: LowerIRTransformer +) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + """ + Rewrite an IR node and extract partitioning information. + + Parameters + ---------- + ir + IR node to rewrite. + rec + Recursive LowerIRTransformer callable. + + Returns + ------- + new_ir, partition_info + The rewritten node, and a mapping from unique nodes in + the full IR graph to associated partitioning information. + + Notes + ----- + This function is used by `lower_ir_graph`. + + See Also + -------- + lower_ir_graph + """ + raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + + +@lower_ir_node.register(IR) +def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + if len(ir.children) == 0: + # Default leaf node has single partition + return ir, {ir: PartitionInfo(count=1)} + + # Lower children + children, _partition_info = zip(*(rec(c) for c in ir.children), strict=False) + partition_info = reduce(operator.or_, _partition_info) + + # Check that child partitioning is supported + count = max(partition_info[c].count for c in children) + if count > 1: + raise NotImplementedError( + f"Class {type(ir)} does not support multiple partitions." + ) # pragma: no cover + + # Return reconstructed node and partition-info dict + partition = PartitionInfo(count=1) + new_node = ir.reconstruct(children) + partition_info[new_node] = partition + return new_node, partition_info + + +def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: + """ + Rewrite an IR graph and extract partitioning information. + + Parameters + ---------- + ir + Root of the graph to rewrite. + + Returns + ------- + new_ir, partition_info + The rewritten graph, and a mapping from unique nodes + in the new graph to associated partitioning information. + + Notes + ----- + This function traverses the unique nodes of the graph with + root `ir`, and applies :func:`lower_ir_node` to each node. + + See Also + -------- + lower_ir_node + """ + from cudf_polars.dsl.traversal import CachingVisitor + + mapper = CachingVisitor(lower_ir_node) + return mapper(ir) + + +@singledispatch +def generate_ir_tasks( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + """ + Generate a task graph for evaluation of an IR node. + + Parameters + ---------- + ir + IR node to generate tasks for. + partition_info + Partitioning information, obtained from :func:`lower_ir_graph`. + + Returns + ------- + mapping + A (partial) dask task graph for the evaluation of an ir node. + + Notes + ----- + Task generation should only produce the tasks for the current node, + referring to child tasks by name. + + See Also + -------- + task_graph + """ + raise AssertionError(f"Unhandled type {type(ir)}") # pragma: no cover + + +@generate_ir_tasks.register(IR) +def _( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> MutableMapping[Any, Any]: + # Single-partition default behavior. + # This is used by `generate_ir_tasks` for all unregistered IR sub-types. + if partition_info[ir].count > 1: + raise NotImplementedError( + f"Failed to generate multiple output tasks for {ir}." + ) # pragma: no cover + + child_names = [] + for child in ir.children: + child_names.append(get_key_name(child)) + if partition_info[child].count > 1: + raise NotImplementedError( + f"Failed to generate tasks for {ir} with child {child}." + ) # pragma: no cover + + key_name = get_key_name(ir) + return { + (key_name, 0): ( + ir.do_evaluate, + *ir._non_child_args, + *((child_name, 0) for child_name in child_names), + ) + } + + +def task_graph( + ir: IR, partition_info: MutableMapping[IR, PartitionInfo] +) -> tuple[MutableMapping[Any, Any], str | tuple[str, int]]: + """ + Construct a task graph for evaluation of an IR graph. + + Parameters + ---------- + ir + Root of the graph to rewrite. + partition_info + A mapping from all unique IR nodes to the + associated partitioning information. + + Returns + ------- + graph + A Dask-compatible task graph for the entire + IR graph with root `ir`. + + Notes + ----- + This function traverses the unique nodes of the + graph with root `ir`, and extracts the tasks for + each node with :func:`generate_ir_tasks`. + + See Also + -------- + generate_ir_tasks + """ + graph = reduce( + operator.or_, + (generate_ir_tasks(node, partition_info) for node in traversal(ir)), + ) + return graph, (get_key_name(ir), 0) + + +def evaluate_dask(ir: IR) -> DataFrame: + """Evaluate an IR graph with Dask.""" + from dask import get + + ir, partition_info = lower_ir_graph(ir) + + graph, key = task_graph(ir, partition_info) + return get(graph, key) diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index ba0bb12a0fb..d986f150b2e 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -20,6 +20,11 @@ __all__: list[str] = ["assert_gpu_result_equal", "assert_ir_translation_raises"] +# Will be overriden by `conftest.py` with the value from the `--executor` +# command-line argument +Executor = None + + def assert_gpu_result_equal( lazydf: pl.LazyFrame, *, @@ -34,6 +39,7 @@ def assert_gpu_result_equal( rtol: float = 1e-05, atol: float = 1e-08, categorical_as_str: bool = False, + executor: str | None = None, ) -> None: """ Assert that collection of a lazyframe on GPU produces correct results. @@ -71,6 +77,9 @@ def assert_gpu_result_equal( Absolute tolerance for float comparisons categorical_as_str Decat categoricals to strings before comparing + executor + The executor configuration to pass to `GPUEngine`. If not specified + uses the module level `Executor` attribute. Raises ------ @@ -80,7 +89,7 @@ def assert_gpu_result_equal( If GPU collection failed in some way. """ if engine is None: - engine = GPUEngine(raise_on_fail=True) + engine = GPUEngine(raise_on_fail=True, executor=executor or Executor) final_polars_collect_kwargs, final_cudf_collect_kwargs = _process_kwargs( collect_kwargs, polars_collect_kwargs, cudf_collect_kwargs diff --git a/python/cudf_polars/tests/conftest.py b/python/cudf_polars/tests/conftest.py index 9bbce6bc080..6338bf0cae1 100644 --- a/python/cudf_polars/tests/conftest.py +++ b/python/cudf_polars/tests/conftest.py @@ -8,3 +8,19 @@ @pytest.fixture(params=[False, True], ids=["no_nulls", "nulls"], scope="session") def with_nulls(request): return request.param + + +def pytest_addoption(parser): + parser.addoption( + "--executor", + action="store", + default="pylibcudf", + choices=("pylibcudf", "dask-experimental"), + help="Executor to use for GPUEngine.", + ) + + +def pytest_configure(config): + import cudf_polars.testing.asserts + + cudf_polars.testing.asserts.Executor = config.getoption("--executor") diff --git a/python/cudf_polars/tests/experimental/test_parallel.py b/python/cudf_polars/tests/experimental/test_parallel.py new file mode 100644 index 00000000000..d46ab88eebf --- /dev/null +++ b/python/cudf_polars/tests/experimental/test_parallel.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import polars as pl +from polars import GPUEngine +from polars.testing import assert_frame_equal + + +def test_evaluate_dask(): + df = pl.LazyFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7], "d": [7, 9, 8]}) + q = df.select(pl.col("a") - (pl.col("b") + pl.col("c") * 2), pl.col("d")).sort("d") + + expected = q.collect(engine="cpu") + got_gpu = q.collect(engine=GPUEngine(raise_on_fail=True)) + got_dask = q.collect( + engine=GPUEngine(raise_on_fail=True, executor="dask-experimental") + ) + assert_frame_equal(expected, got_gpu) + assert_frame_equal(expected, got_dask) diff --git a/python/cudf_polars/tests/test_executors.py b/python/cudf_polars/tests/test_executors.py new file mode 100644 index 00000000000..3eaea2ec9ea --- /dev/null +++ b/python/cudf_polars/tests/test_executors.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.mark.parametrize("executor", [None, "pylibcudf", "dask-experimental"]) +def test_executor_basics(executor): + if executor == "dask-experimental": + pytest.importorskip("dask") + + df = pl.LazyFrame( + { + "a": pl.Series([[1, 2], [3]], dtype=pl.List(pl.Int8())), + "b": pl.Series([[1], [2]], dtype=pl.List(pl.UInt16())), + "c": pl.Series( + [ + [["1", "2", "3"], ["4", "567"]], + [["8", "9"], []], + ], + dtype=pl.List(pl.List(pl.String())), + ), + "d": pl.Series([[[1, 2]], []], dtype=pl.List(pl.List(pl.UInt16()))), + } + ) + + assert_gpu_result_equal(df, executor=executor) + + +def test_cudf_cache_evaluate(): + ldf = pl.DataFrame( + { + "a": [1, 2, 3, 4, 5, 6, 7], + "b": [1, 1, 1, 1, 1, 1, 1], + } + ).lazy() + ldf2 = ldf.select((pl.col("a") + pl.col("b")).alias("c"), pl.col("a")) + query = pl.concat([ldf, ldf2], how="diagonal") + assert_gpu_result_equal(query, executor="pylibcudf") + + +def test_dask_experimental_map_function_get_hashable(): + df = pl.LazyFrame( + { + "a": pl.Series([11, 12, 13], dtype=pl.UInt16), + "b": pl.Series([1, 3, 5], dtype=pl.Int16), + "c": pl.Series([2, 4, 6], dtype=pl.Float32), + "d": ["a", "b", "c"], + } + ) + q = df.unpivot(index="d") + assert_gpu_result_equal(q, executor="dask-experimental") + + +def test_unknown_executor(): + df = pl.LazyFrame({}) + + with pytest.raises( + pl.exceptions.ComputeError, + match="ValueError: Unknown executor 'unknown-executor'", + ): + assert_gpu_result_equal(df, executor="unknown-executor")