Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract polars function expression nodes to ensure they are serializable #17418

Merged
merged 14 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 56 additions & 21 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

from __future__ import annotations

from enum import IntEnum, auto
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa

from polars.polars import _expr_nodes as pl_expr

import pylibcudf as plc

from cudf_polars.containers import Column
Expand All @@ -24,29 +23,65 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from typing_extensions import Self

import polars.type_aliases as pl_types
from polars.polars import _expr_nodes as pl_expr

from cudf_polars.containers import DataFrame

__all__ = ["BooleanFunction"]


class BooleanFunction(Expr):
class Name(IntEnum):
    """Internal and picklable representation of polars' `BooleanFunction`."""

    All = auto()
    AllHorizontal = auto()
    Any = auto()
    AnyHorizontal = auto()
    IsBetween = auto()
    IsDuplicated = auto()
    IsFinite = auto()
    IsFirstDistinct = auto()
    IsIn = auto()
    IsInfinite = auto()
    IsLastDistinct = auto()
    IsNan = auto()
    IsNotNan = auto()
    IsNotNull = auto()
    IsNull = auto()
    IsUnique = auto()
    Not = auto()

    @classmethod
    def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
        """
        Convert from polars' `BooleanFunction`.

        Parameters
        ----------
        obj
            Object whose ``str()`` has the form
            ``"BooleanFunction.<Member>"``.

        Returns
        -------
        The member of this enum whose name is ``<Member>``.

        Raises
        ------
        ValueError
            If the string form of ``obj`` does not start with
            ``"BooleanFunction."``.
        AttributeError
            If ``<Member>`` is not the name of a member of this enum.
        """
        function, sep, name = str(obj).partition(".")
        if not sep or function != "BooleanFunction":
            raise ValueError("BooleanFunction required")
        try:
            # Look up in the member table rather than with getattr so
            # that only real enum members can be returned, never
            # methods or other class attributes.
            return cls.__members__[name]
        except KeyError:
            raise AttributeError(
                f"{cls.__name__} has no member {name!r}"
            ) from None

__slots__ = ("name", "options")
_non_child = ("dtype", "name", "options")

def __init__(
self,
dtype: plc.DataType,
name: pl_expr.BooleanFunction,
name: BooleanFunction.Name,
options: tuple[Any, ...],
*children: Expr,
) -> None:
self.dtype = dtype
self.options = options
self.name = name
self.children = children
if self.name == pl_expr.BooleanFunction.IsIn and not all(
if self.name is BooleanFunction.Name.IsIn and not all(
c.dtype == self.children[0].dtype for c in self.children
):
# TODO: If polars IR doesn't put the casts in, we need to
Expand Down Expand Up @@ -110,12 +145,12 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if self.name in (
pl_expr.BooleanFunction.IsFinite,
pl_expr.BooleanFunction.IsInfinite,
BooleanFunction.Name.IsFinite,
BooleanFunction.Name.IsInfinite,
):
# Avoid evaluating the child if the dtype tells us it's unnecessary.
(child,) = self.children
is_finite = self.name == pl_expr.BooleanFunction.IsFinite
is_finite = self.name is BooleanFunction.Name.IsFinite
if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
value = plc.interop.from_arrow(
pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype))
Expand All @@ -142,10 +177,10 @@ def do_evaluate(
]
# Kleene logic for Any (OR) and All (AND) if ignore_nulls is
# False
if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All):
if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All):
(ignore_nulls,) = self.options
(column,) = columns
is_any = self.name == pl_expr.BooleanFunction.Any
is_any = self.name is BooleanFunction.Name.Any
agg = plc.aggregation.any() if is_any else plc.aggregation.all()
result = plc.reduce.reduce(column.obj, agg, self.dtype)
if not ignore_nulls and column.obj.null_count() > 0:
Expand All @@ -165,27 +200,27 @@ def do_evaluate(
# False || Null => Null True && Null => Null
return Column(plc.Column.all_null_like(column.obj, 1))
return Column(plc.Column.from_scalar(result, 1))
if self.name == pl_expr.BooleanFunction.IsNull:
if self.name is BooleanFunction.Name.IsNull:
(column,) = columns
return Column(plc.unary.is_null(column.obj))
elif self.name == pl_expr.BooleanFunction.IsNotNull:
elif self.name is BooleanFunction.Name.IsNotNull:
(column,) = columns
return Column(plc.unary.is_valid(column.obj))
elif self.name == pl_expr.BooleanFunction.IsNan:
elif self.name is BooleanFunction.Name.IsNan:
(column,) = columns
return Column(
plc.unary.is_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
)
)
elif self.name == pl_expr.BooleanFunction.IsNotNan:
elif self.name is BooleanFunction.Name.IsNotNan:
(column,) = columns
return Column(
plc.unary.is_not_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
)
)
elif self.name == pl_expr.BooleanFunction.IsFirstDistinct:
elif self.name is BooleanFunction.Name.IsFirstDistinct:
(column,) = columns
return self._distinct(
column,
Expand All @@ -197,7 +232,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.IsLastDistinct:
elif self.name is BooleanFunction.Name.IsLastDistinct:
(column,) = columns
return self._distinct(
column,
Expand All @@ -209,7 +244,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.IsUnique:
elif self.name is BooleanFunction.Name.IsUnique:
(column,) = columns
return self._distinct(
column,
Expand All @@ -221,7 +256,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.IsDuplicated:
elif self.name is BooleanFunction.Name.IsDuplicated:
(column,) = columns
return self._distinct(
column,
Expand All @@ -233,7 +268,7 @@ def do_evaluate(
pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.AllHorizontal:
elif self.name is BooleanFunction.Name.AllHorizontal:
return Column(
reduce(
partial(
Expand All @@ -244,7 +279,7 @@ def do_evaluate(
(c.obj for c in columns),
)
)
elif self.name == pl_expr.BooleanFunction.AnyHorizontal:
elif self.name is BooleanFunction.Name.AnyHorizontal:
return Column(
reduce(
partial(
Expand All @@ -255,10 +290,10 @@ def do_evaluate(
(c.obj for c in columns),
)
)
elif self.name == pl_expr.BooleanFunction.IsIn:
elif self.name is BooleanFunction.Name.IsIn:
needles, haystack = columns
return Column(plc.search.contains(haystack.obj, needles.obj))
elif self.name == pl_expr.BooleanFunction.Not:
elif self.name is BooleanFunction.Name.Not:
(column,) = columns
return Column(
plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT)
Expand Down
98 changes: 79 additions & 19 deletions python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@

from __future__ import annotations

from enum import IntEnum, auto
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa

from polars.polars import _expr_nodes as pl_expr

import pylibcudf as plc

from cudf_polars.containers import Column
Expand All @@ -20,33 +19,94 @@
if TYPE_CHECKING:
from collections.abc import Mapping

from typing_extensions import Self

from polars.polars import _expr_nodes as pl_expr

from cudf_polars.containers import DataFrame

__all__ = ["TemporalFunction"]


class TemporalFunction(Expr):
class Name(IntEnum):
    """Internal and picklable representation of polars' `TemporalFunction`."""

    BaseUtcOffset = auto()
    CastTimeUnit = auto()
    Century = auto()
    Combine = auto()
    ConvertTimeZone = auto()
    DSTOffset = auto()
    Date = auto()
    Datetime = auto()
    DatetimeFunction = auto()
    Day = auto()
    Duration = auto()
    Hour = auto()
    IsLeapYear = auto()
    IsoYear = auto()
    Microsecond = auto()
    Millennium = auto()
    Millisecond = auto()
    Minute = auto()
    Month = auto()
    MonthEnd = auto()
    MonthStart = auto()
    Nanosecond = auto()
    OffsetBy = auto()
    OrdinalDay = auto()
    Quarter = auto()
    ReplaceTimeZone = auto()
    Round = auto()
    Second = auto()
    Time = auto()
    TimeStamp = auto()
    ToString = auto()
    TotalDays = auto()
    TotalHours = auto()
    TotalMicroseconds = auto()
    TotalMilliseconds = auto()
    TotalMinutes = auto()
    TotalNanoseconds = auto()
    TotalSeconds = auto()
    Truncate = auto()
    Week = auto()
    WeekDay = auto()
    WithTimeUnit = auto()
    Year = auto()

    @classmethod
    def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self:
        """
        Convert from polars' `TemporalFunction`.

        Parameters
        ----------
        obj
            Object whose ``str()`` has the form
            ``"TemporalFunction.<Member>"``.

        Returns
        -------
        The member of this enum whose name is ``<Member>``.

        Raises
        ------
        ValueError
            If the string form of ``obj`` does not start with
            ``"TemporalFunction."``.
        AttributeError
            If ``<Member>`` is not the name of a member of this enum.
        """
        function, sep, name = str(obj).partition(".")
        if not sep or function != "TemporalFunction":
            raise ValueError("TemporalFunction required")
        try:
            # Look up in the member table rather than with getattr so
            # that only real enum members can be returned, never
            # methods or other class attributes.
            return cls.__members__[name]
        except KeyError:
            raise AttributeError(
                f"{cls.__name__} has no member {name!r}"
            ) from None

__slots__ = ("name", "options")
_COMPONENT_MAP: ClassVar[
dict[pl_expr.TemporalFunction, plc.datetime.DatetimeComponent]
] = {
pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR,
pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH,
pl_expr.TemporalFunction.Day: plc.datetime.DatetimeComponent.DAY,
pl_expr.TemporalFunction.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
pl_expr.TemporalFunction.Hour: plc.datetime.DatetimeComponent.HOUR,
pl_expr.TemporalFunction.Minute: plc.datetime.DatetimeComponent.MINUTE,
pl_expr.TemporalFunction.Second: plc.datetime.DatetimeComponent.SECOND,
pl_expr.TemporalFunction.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
pl_expr.TemporalFunction.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
pl_expr.TemporalFunction.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
}
_non_child = ("dtype", "name", "options")
_COMPONENT_MAP: ClassVar[dict[Name, plc.datetime.DatetimeComponent]] = {
Name.Year: plc.datetime.DatetimeComponent.YEAR,
Name.Month: plc.datetime.DatetimeComponent.MONTH,
Name.Day: plc.datetime.DatetimeComponent.DAY,
Name.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
Name.Hour: plc.datetime.DatetimeComponent.HOUR,
Name.Minute: plc.datetime.DatetimeComponent.MINUTE,
Name.Second: plc.datetime.DatetimeComponent.SECOND,
Name.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
Name.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
Name.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
}

def __init__(
self,
dtype: plc.DataType,
name: pl_expr.TemporalFunction,
name: TemporalFunction.Name,
options: tuple[Any, ...],
*children: Expr,
) -> None:
Expand All @@ -70,7 +130,7 @@ def do_evaluate(
for child in self.children
]
(column,) = columns
if self.name == pl_expr.TemporalFunction.Microsecond:
if self.name is TemporalFunction.Name.Microsecond:
millis = plc.datetime.extract_datetime_component(
column.obj, plc.datetime.DatetimeComponent.MILLISECOND
)
Expand All @@ -90,7 +150,7 @@ def do_evaluate(
plc.types.DataType(plc.types.TypeId.INT32),
)
return Column(total_micros)
elif self.name == pl_expr.TemporalFunction.Nanosecond:
elif self.name is TemporalFunction.Name.Nanosecond:
millis = plc.datetime.extract_datetime_component(
column.obj, plc.datetime.DatetimeComponent.MILLISECOND
)
Expand Down
Loading
Loading