Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify treatment of Expr and IR nodes in cudf-polars DSL #17016

Merged
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@


class Agg(Expr):
__slots__ = ("name", "options", "op", "request", "children")
__slots__ = ("name", "options", "op", "request")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(
self, dtype: plc.DataType, name: str, options: Any, *children: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.name = name
self.options = options
self.children = children
Expand Down
97 changes: 7 additions & 90 deletions python/cudf_polars/cudf_polars/dsl/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.dsl.nodebase import Node

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Mapping

from cudf_polars.containers import Column, DataFrame

Expand All @@ -32,100 +33,16 @@ class ExecutionContext(IntEnum):
ROLLING = enum.auto()


class Expr:
"""
An abstract expression object.
class Expr(Node["Expr"]):
"""An abstract expression object."""

This contains a (potentially empty) tuple of child expressions,
along with non-child data. For uniform reconstruction and
implementation of hashing and equality schemes, child classes need
to provide a certain amount of metadata when they are defined.
Specifically, the ``_non_child`` attribute must list, in-order,
the names of the slots that are passed to the constructor. The
constructor must take arguments in the order ``(*_non_child,
*children).``
"""

__slots__ = ("dtype", "_hash_value", "_repr_value")
__slots__ = ("dtype",)
dtype: plc.DataType
"""Data type of the expression."""
_hash_value: int
"""Caching slot for the hash of the expression."""
_repr_value: str
"""Caching slot for repr of the expression."""
children: tuple[Expr, ...] = ()
"""Children of the expression."""
# This annotation is needed because of https://github.com/python/mypy/issues/17981
_non_child: ClassVar[tuple[str, ...]] = ("dtype",)
"""Names of non-child data (not Exprs) for reconstruction."""

# Constructor must take arguments in order (*_non_child, *children)
def __init__(self, dtype: plc.DataType) -> None:
self.dtype = dtype
vyasr marked this conversation as resolved.
Show resolved Hide resolved

def _ctor_arguments(self, children: Sequence[Expr]) -> Sequence:
return (*(getattr(self, attr) for attr in self._non_child), *children)

def get_hash(self) -> int:
"""
Return the hash of this expr.

Override this in subclasses, rather than __hash__.

Returns
-------
The integer hash value.
"""
return hash((type(self), self._ctor_arguments(self.children)))

def __hash__(self) -> int:
"""Hash of an expression with caching."""
try:
return self._hash_value
except AttributeError:
self._hash_value = self.get_hash()
return self._hash_value

def is_equal(self, other: Any) -> bool:
"""
Equality of two expressions.

Override this in subclasses, rather than __eq__.

Parameter
---------
other
object to compare to

Returns
-------
True if the two expressions are equal, false otherwise.
"""
if type(self) is not type(other):
return False # pragma: no cover; __eq__ trips first
return self._ctor_arguments(self.children) == other._ctor_arguments(
other.children
)

def __eq__(self, other: Any) -> bool:
"""Equality of expressions."""
if type(self) is not type(other) or hash(self) != hash(other):
return False
else:
return self.is_equal(other)

def __ne__(self, other: Any) -> bool:
"""Inequality of expressions."""
return not self.__eq__(other)

def __repr__(self) -> str:
"""String representation of an expression with caching."""
try:
return self._repr_value
except AttributeError:
args = ", ".join(f"{arg!r}" for arg in self._ctor_arguments(self.children))
self._repr_value = f"{type(self).__name__}({args})"
return self._repr_value

def do_evaluate(
self,
df: DataFrame,
Expand Down Expand Up @@ -311,11 +228,11 @@ class Col(Expr):
__slots__ = ("name",)
_non_child = ("dtype", "name")
name: str
children: tuple[()]

def __init__(self, dtype: plc.DataType, name: str) -> None:
self.dtype = dtype
self.name = name
self.children = ()

def do_evaluate(
self,
Expand Down
5 changes: 2 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/binaryop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@


class BinOp(Expr):
__slots__ = ("op", "children")
__slots__ = ("op",)
_non_child = ("dtype", "op")
children: tuple[Expr, Expr]

def __init__(
self,
Expand All @@ -35,7 +34,7 @@ def __init__(
left: Expr,
right: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
if plc.traits.is_boolean(self.dtype):
# For boolean output types, bitand and bitor implement
# boolean logic, so translate. bitxor also does, but the
Expand Down
5 changes: 2 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@


class BooleanFunction(Expr):
__slots__ = ("name", "options", "children")
__slots__ = ("name", "options")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(
self,
Expand All @@ -42,7 +41,7 @@ def __init__(
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.name = name
self.children = children
Expand Down
5 changes: 2 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


class TemporalFunction(Expr):
__slots__ = ("name", "options", "children")
__slots__ = ("name", "options")
_COMPONENT_MAP: ClassVar[dict[pl_expr.TemporalFunction, str]] = {
pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR,
pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH,
Expand All @@ -39,7 +39,6 @@ class TemporalFunction(Expr):
pl_expr.TemporalFunction.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
}
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(
self,
Expand All @@ -48,7 +47,7 @@ def __init__(
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.name = name
self.children = children
Expand Down
14 changes: 7 additions & 7 deletions python/cudf_polars/cudf_polars/dsl/expressions/literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Hashable, Mapping

import pyarrow as pa

Expand All @@ -31,12 +31,12 @@ class Literal(Expr):
__slots__ = ("value",)
_non_child = ("dtype", "value")
value: pa.Scalar[Any]
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None:
super().__init__(dtype)
self.dtype = dtype
assert value.type == plc.interop.to_arrow(dtype)
self.value = value
self.children = ()

def do_evaluate(
self,
Expand All @@ -58,19 +58,19 @@ class LiteralColumn(Expr):
__slots__ = ("value",)
_non_child = ("dtype", "value")
value: pa.Array[Any, Any]
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
super().__init__(dtype)
self.dtype = dtype
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
self.children = ()

def get_hash(self) -> int:
def get_hashable(self) -> Hashable:
"""Compute a hash of the column."""
# This is stricter than necessary, but we only need this hash
# for identity in groupby replacements so it's OK. And this
# way we avoid doing potentially expensive compute.
return hash((type(self), self.dtype, id(self.value)))
return (type(self), self.dtype, id(self.value))

def do_evaluate(
self,
Expand Down
10 changes: 4 additions & 6 deletions python/cudf_polars/cudf_polars/dsl/expressions/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,22 @@


class RollingWindow(Expr):
__slots__ = ("options", "children")
__slots__ = ("options",)
_non_child = ("dtype", "options")
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (agg,)
raise NotImplementedError("Rolling window not implemented")


class GroupedRollingWindow(Expr):
__slots__ = ("options", "children")
__slots__ = ("options",)
_non_child = ("dtype", "options")
children: tuple[Expr, ...]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (agg, *by)
raise NotImplementedError("Grouped rolling window not implemented")
10 changes: 4 additions & 6 deletions python/cudf_polars/cudf_polars/dsl/expressions/selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@


class Gather(Expr):
__slots__ = ("children",)
__slots__ = ()
_non_child = ("dtype",)
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None:
super().__init__(dtype)
self.dtype = dtype
self.children = (values, indices)

def do_evaluate(
Expand Down Expand Up @@ -65,12 +64,11 @@ def do_evaluate(


class Filter(Expr):
__slots__ = ("children",)
__slots__ = ()
_non_child = ("dtype",)
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
super().__init__(dtype)
self.dtype = dtype
self.children = (values, indices)

def do_evaluate(
Expand Down
10 changes: 4 additions & 6 deletions python/cudf_polars/cudf_polars/dsl/expressions/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@


class Sort(Expr):
__slots__ = ("options", "children")
__slots__ = ("options",)
_non_child = ("dtype", "options")
children: tuple[Expr]

def __init__(
self, dtype: plc.DataType, options: tuple[bool, bool, bool], column: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (column,)

Expand Down Expand Up @@ -59,9 +58,8 @@ def do_evaluate(


class SortBy(Expr):
__slots__ = ("options", "children")
__slots__ = ("options",)
_non_child = ("dtype", "options")
children: tuple[Expr, ...]

def __init__(
self,
Expand All @@ -70,7 +68,7 @@ def __init__(
column: Expr,
*by: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.children = (column, *by)

Expand Down
5 changes: 2 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@


class StringFunction(Expr):
__slots__ = ("name", "options", "children", "_regex_program")
__slots__ = ("name", "options", "_regex_program")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(
self,
Expand All @@ -39,7 +38,7 @@ def __init__(
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.options = options
self.name = name
self.children = children
Expand Down
5 changes: 2 additions & 3 deletions python/cudf_polars/cudf_polars/dsl/expressions/ternary.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@


class Ternary(Expr):
__slots__ = ("children",)
__slots__ = ()
_non_child = ("dtype",)
children: tuple[Expr, Expr, Expr]

def __init__(
self, dtype: plc.DataType, when: Expr, then: Expr, otherwise: Expr
) -> None:
super().__init__(dtype)
self.dtype = dtype
self.children = (when, then, otherwise)

def do_evaluate(
Expand Down
Loading
Loading