Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partition-wise Select support to cuDF-Polars #17495

Merged
merged 17 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(
self.dtype = dtype
self.name = name
self.options = options
self.is_pointwise = False
self.children = children
if name not in Agg._SUPPORTED:
raise NotImplementedError(
Expand Down
7 changes: 6 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ class ExecutionContext(IntEnum):
class Expr(Node["Expr"]):
"""An abstract expression object."""

__slots__ = ("dtype",)
__slots__ = ("dtype", "is_pointwise")
dtype: plc.DataType
"""Data type of the expression."""
is_pointwise: bool
"""Whether this expression acts pointwise on its inputs."""
# This annotation is needed because of https://github.com/python/mypy/issues/17981
_non_child: ClassVar[tuple[str, ...]] = ("dtype",)
"""Names of non-child data (not Exprs) for reconstruction."""
Expand Down Expand Up @@ -164,6 +166,7 @@ def __init__(self, dtype: plc.DataType, error: str) -> None:
self.dtype = dtype
self.error = error
self.children = ()
self.is_pointwise = True


class NamedExpr:
Expand Down Expand Up @@ -243,6 +246,7 @@ class Col(Expr):
def __init__(self, dtype: plc.DataType, name: str) -> None:
self.dtype = dtype
self.name = name
self.is_pointwise = True
self.children = ()

def do_evaluate(
Expand Down Expand Up @@ -280,6 +284,7 @@ def __init__(
self.dtype = dtype
self.index = index
self.table_ref = table_ref
self.is_pointwise = True
self.children = (column,)

def do_evaluate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
op = BinOp._BOOL_KLEENE_MAPPING.get(op, op)
self.op = op
self.children = (left, right)
self.is_pointwise = True
if not plc.binaryop.is_supported_operation(
self.dtype, left.dtype, right.dtype, op
):
Expand Down
8 changes: 8 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def __init__(
self.options = options
self.name = name
self.children = children
self.is_pointwise = self.name not in (
BooleanFunction.Name.All,
BooleanFunction.Name.Any,
BooleanFunction.Name.IsDuplicated,
BooleanFunction.Name.IsFirstDistinct,
BooleanFunction.Name.IsLastDistinct,
BooleanFunction.Name.IsUnique,
)
if self.name is BooleanFunction.Name.IsIn and not all(
c.dtype == self.children[0].dtype for c in self.children
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self.options = options
self.name = name
self.children = children
self.is_pointwise = True
if self.name not in self._COMPONENT_MAP:
raise NotImplementedError(f"Temporal function {self.name}")

Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None:
assert value.type == plc.interop.to_arrow(dtype)
self.value = value
self.children = ()
self.is_pointwise = True

def do_evaluate(
self,
Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
self.children = ()
self.is_pointwise = True

def get_hashable(self) -> Hashable:
"""Compute a hash of the column."""
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
self.dtype = dtype
self.options = options
self.children = (agg,)
self.is_pointwise = False
raise NotImplementedError("Rolling window not implemented")


Expand All @@ -35,4 +36,5 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> N
self.dtype = dtype
self.options = options
self.children = (agg, *by)
self.is_pointwise = False
raise NotImplementedError("Grouped rolling window not implemented")
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Gather(Expr):
def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None:
self.dtype = dtype
self.children = (values, indices)
self.is_pointwise = False

def do_evaluate(
self,
Expand Down Expand Up @@ -71,6 +72,7 @@ class Filter(Expr):
def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
self.dtype = dtype
self.children = (values, indices)
self.is_pointwise = True

def do_evaluate(
self,
Expand Down
2 changes: 2 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self.dtype = dtype
self.options = options
self.children = (column,)
self.is_pointwise = False

def do_evaluate(
self,
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(
self.dtype = dtype
self.options = options
self.children = (column, *by)
self.is_pointwise = False

def do_evaluate(
self,
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
self.options = options
self.name = name
self.children = children
self.is_pointwise = True
self._validate_input()

def _validate_input(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
) -> None:
self.dtype = dtype
self.children = (when, then, otherwise)
self.is_pointwise = True

def do_evaluate(
self,
Expand Down
10 changes: 10 additions & 0 deletions python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Cast(Expr):
def __init__(self, dtype: plc.DataType, value: Expr) -> None:
self.dtype = dtype
self.children = (value,)
self.is_pointwise = True
if not dtypes.can_cast(value.dtype, self.dtype):
raise NotImplementedError(
f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
Expand Down Expand Up @@ -63,6 +64,7 @@ class Len(Expr):
def __init__(self, dtype: plc.DataType) -> None:
self.dtype = dtype
self.children = ()
self.is_pointwise = False

def do_evaluate(
self,
Expand Down Expand Up @@ -147,6 +149,14 @@ def __init__(
self.name = name
self.options = options
self.children = children
self.is_pointwise = self.name not in (
"cum_min",
"cum_max",
"cum_prod",
"cum_sum",
"drop_nulls",
"unique",
)

if self.name not in UnaryFunction._supported_fns:
raise NotImplementedError(f"Unary function {name=}")
Expand Down
14 changes: 7 additions & 7 deletions python/cudf_polars/cudf_polars/dsl/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cudf_polars.typing import U_contra, V_co

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Mapping, MutableMapping
from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence

from cudf_polars.typing import GenericTransformer, NodeT

Expand All @@ -23,22 +23,22 @@
]


def traversal(node: NodeT) -> Generator[NodeT, None, None]:
def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
"""
Pre-order traversal of nodes in an expression.

Parameters
----------
node
Root of expression to traverse.
nodes
Roots of expressions to traverse.

Yields
------
Unique nodes in the expression, parent before child, children
Unique nodes in the expressions, parent before child, children
in-order from left to right.
"""
seen = {node}
lifo = [node]
seen = set(nodes)
lifo = list(nodes)

while lifo:
node = lifo.pop()
Expand Down
12 changes: 9 additions & 3 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

import cudf_polars.experimental.io # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Projection, Union
import cudf_polars.experimental.io
import cudf_polars.experimental.select # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
from cudf_polars.dsl.traversal import CachingVisitor, traversal
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import (
Expand Down Expand Up @@ -112,7 +113,7 @@ def task_graph(
"""
graph = reduce(
operator.or_,
(generate_ir_tasks(node, partition_info) for node in traversal(ir)),
(generate_ir_tasks(node, partition_info) for node in traversal([ir])),
)

key_name = get_key_name(ir)
Expand Down Expand Up @@ -226,6 +227,8 @@ def _lower_ir_pwise(

lower_ir_node.register(Projection, _lower_ir_pwise)
lower_ir_node.register(Cache, _lower_ir_pwise)
lower_ir_node.register(Filter, _lower_ir_pwise)
lower_ir_node.register(HStack, _lower_ir_pwise)


def _generate_ir_tasks_pwise(
Expand All @@ -245,3 +248,6 @@ def _generate_ir_tasks_pwise(

generate_ir_tasks.register(Projection, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Filter, _generate_ir_tasks_pwise)
generate_ir_tasks.register(HStack, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Select, _generate_ir_tasks_pwise)
36 changes: 36 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Parallel Select Logic."""

from __future__ import annotations

from typing import TYPE_CHECKING

from cudf_polars.dsl.ir import Select
from cudf_polars.dsl.traversal import traversal
from cudf_polars.experimental.dispatch import lower_ir_node

if TYPE_CHECKING:
from collections.abc import MutableMapping

from cudf_polars.dsl.ir import IR
from cudf_polars.experimental.base import PartitionInfo
from cudf_polars.experimental.parallel import LowerIRTransformer


@lower_ir_node.register(Select)
def _(
ir: Select, rec: LowerIRTransformer
) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
child, partition_info = rec(ir.children[0])
pi = partition_info[child]
if pi.count > 1 and not all(
expr.is_pointwise for expr in traversal([e.value for e in ir.exprs])
):
# TODO: Handle non-pointwise expressions.
raise NotImplementedError(
f"Selection {ir} does not support multiple partitions."
)
new_node = ir.reconstruct([child])
partition_info[new_node] = pi
return new_node, partition_info
6 changes: 3 additions & 3 deletions python/cudf_polars/tests/dsl/test_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ def test_traversal_unique():
dt = plc.DataType(plc.TypeId.INT8)

e1 = make_expr(dt, "a", "a")
unique_exprs = list(traversal(e1))
unique_exprs = list(traversal([e1]))

assert len(unique_exprs) == 2
assert set(unique_exprs) == {expr.Col(dt, "a"), e1}
assert unique_exprs == [e1, expr.Col(dt, "a")]

e2 = make_expr(dt, "a", "b")
unique_exprs = list(traversal(e2))
unique_exprs = list(traversal([e2]))

assert len(unique_exprs) == 3
assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e2}
assert unique_exprs == [e2, expr.Col(dt, "a"), expr.Col(dt, "b")]

e3 = make_expr(dt, "b", "a")
unique_exprs = list(traversal(e3))
unique_exprs = list(traversal([e3]))

assert len(unique_exprs) == 3
assert set(unique_exprs) == {expr.Col(dt, "a"), expr.Col(dt, "b"), e3}
Expand Down
54 changes: 54 additions & 0 deletions python/cudf_polars/tests/experimental/test_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import pytest

import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


@pytest.fixture(scope="module")
def engine():
return pl.GPUEngine(
raise_on_fail=True,
executor="dask-experimental",
executor_options={"max_rows_per_partition": 3},
)


@pytest.fixture(scope="module")
def df():
return pl.LazyFrame(
{
"a": [1, 2, 3, 4, 5, 6, 7],
"b": [1, 1, 1, 1, 1, 1, 1],
}
)


def test_select(df, engine):
query = df.select(
pl.col("a") + pl.col("b"), (pl.col("a") * 2 + pl.col("b")).alias("d")
)
assert_gpu_result_equal(query, engine=engine)


def test_select_reduce_raises(df, engine):
query = df.select(
(pl.col("a") + pl.col("b")).max(),
(pl.col("a") * 2 + pl.col("b")).alias("d").mean(),
)
with pytest.raises(
pl.exceptions.ComputeError,
match="NotImplementedError",
):
assert_gpu_result_equal(query, engine=engine)
Comment on lines +44 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We prefer assert_ir_translation_raises for these kind of things, I think.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parallel execution is currently independent of IR translation. When we raise an error, it's because we ran into a non-"pointwise" Select operation (with multiple partitions) after the IR was already translated successfully.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, thanks.

I think for now this is fine, perhaps in the parallel execution environment we don't want "early/eager" fallback. But it might be worthwhile thinking about.

We can think of this lowering as another step in the "IR translation" phase.



def test_select_with_cse_no_agg(df, engine):
expr = pl.col("a") + pl.col("a")
query = df.select(expr, (expr * 2).alias("b"), ((expr * 2) + 10).alias("c"))
assert_gpu_result_equal(query, engine=engine)
Loading