Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract polars function expression nodes to ensure they are serializable #17418

Merged
merged 14 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

from __future__ import annotations

from enum import Enum, auto
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa

from polars.polars import _expr_nodes as pl_expr

import pylibcudf as plc

from cudf_polars.containers import Column
Expand All @@ -31,22 +30,49 @@
__all__ = ["BooleanFunction"]


class BooleanFunctionName(Enum):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: use IntEnum.

Also, let's nest this class inside the BooleanFunction class below, and just call it Name.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 3f69d98 .

All = auto()
AllHorizontal = auto()
Any = auto()
AnyHorizontal = auto()
IsBetween = auto()
IsDuplicated = auto()
IsFinite = auto()
IsFirstDistinct = auto()
IsIn = auto()
IsInfinite = auto()
IsLastDistinct = auto()
IsNan = auto()
IsNotNan = auto()
IsNotNull = auto()
IsNull = auto()
IsUnique = auto()
Not = auto()

@staticmethod
def get_polars_type(tp: BooleanFunctionName):
function, name = str(tp).split(".")
if function != "BooleanFunction":
raise ValueError("BooleanFunction required")
return getattr(BooleanFunctionName, name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this doesn't return a polars object, but one of ours. Also, because we're calling getattr, that suggests we should make this a classmethod

@classmethod
def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
    function, name = str(obj).split(".", maxsplit=1)
    if function != "BooleanFunction":
        raise ValueError("BooleanFunction required")
    return getattr(cls, name)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I totally mixed everything up here, thanks for catching that. Done in be7fa52 .



class BooleanFunction(Expr):
__slots__ = ("name", "options")
_non_child = ("dtype", "name", "options")

def __init__(
self,
dtype: plc.DataType,
name: pl_expr.BooleanFunction,
name: BooleanFunctionName,
options: tuple[Any, ...],
*children: Expr,
) -> None:
self.dtype = dtype
self.options = options
self.name = name
self.children = children
if self.name == pl_expr.BooleanFunction.IsIn and not all(
if self.name == BooleanFunctionName.IsIn and not all(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Because we have enums, we should now use `is` for comparison.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 7d009f9 .

c.dtype == self.children[0].dtype for c in self.children
):
# TODO: If polars IR doesn't put the casts in, we need to
Expand Down Expand Up @@ -110,12 +136,12 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if self.name in (
pl_expr.BooleanFunction.IsFinite,
pl_expr.BooleanFunction.IsInfinite,
BooleanFunctionName.IsFinite,
BooleanFunctionName.IsInfinite,
):
# Avoid evaluating the child if the dtype tells us it's unnecessary.
(child,) = self.children
is_finite = self.name == pl_expr.BooleanFunction.IsFinite
is_finite = self.name == BooleanFunctionName.IsFinite
if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
value = plc.interop.from_arrow(
pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype))
Expand All @@ -142,10 +168,10 @@ def do_evaluate(
]
# Kleene logic for Any (OR) and All (AND) if ignore_nulls is
# False
if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All):
if self.name in (BooleanFunctionName.Any, BooleanFunctionName.All):
(ignore_nulls,) = self.options
(column,) = columns
is_any = self.name == pl_expr.BooleanFunction.Any
is_any = self.name == BooleanFunctionName.Any
agg = plc.aggregation.any() if is_any else plc.aggregation.all()
result = plc.reduce.reduce(column.obj, agg, self.dtype)
if not ignore_nulls and column.obj.null_count() > 0:
Expand All @@ -165,27 +191,27 @@ def do_evaluate(
# False || Null => Null True && Null => Null
return Column(plc.Column.all_null_like(column.obj, 1))
return Column(plc.Column.from_scalar(result, 1))
if self.name == pl_expr.BooleanFunction.IsNull:
if self.name == BooleanFunctionName.IsNull:
(column,) = columns
return Column(plc.unary.is_null(column.obj))
elif self.name == pl_expr.BooleanFunction.IsNotNull:
elif self.name == BooleanFunctionName.IsNotNull:
(column,) = columns
return Column(plc.unary.is_valid(column.obj))
elif self.name == pl_expr.BooleanFunction.IsNan:
elif self.name == BooleanFunctionName.IsNan:
(column,) = columns
return Column(
plc.unary.is_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
)
)
elif self.name == pl_expr.BooleanFunction.IsNotNan:
elif self.name == BooleanFunctionName.IsNotNan:
(column,) = columns
return Column(
plc.unary.is_not_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
)
)
elif self.name == pl_expr.BooleanFunction.IsFirstDistinct:
elif self.name == BooleanFunctionName.IsFirstDistinct:
(column,) = columns
return self._distinct(
column,
Expand All @@ -197,7 +223,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.IsLastDistinct:
elif self.name == BooleanFunctionName.IsLastDistinct:
(column,) = columns
return self._distinct(
column,
Expand All @@ -209,7 +235,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.IsUnique:
elif self.name == BooleanFunctionName.IsUnique:
(column,) = columns
return self._distinct(
column,
Expand All @@ -221,7 +247,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.IsDuplicated:
elif self.name == BooleanFunctionName.IsDuplicated:
(column,) = columns
return self._distinct(
column,
Expand All @@ -233,7 +259,7 @@ def do_evaluate(
pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == pl_expr.BooleanFunction.AllHorizontal:
elif self.name == BooleanFunctionName.AllHorizontal:
return Column(
reduce(
partial(
Expand All @@ -244,7 +270,7 @@ def do_evaluate(
(c.obj for c in columns),
)
)
elif self.name == pl_expr.BooleanFunction.AnyHorizontal:
elif self.name == BooleanFunctionName.AnyHorizontal:
return Column(
reduce(
partial(
Expand All @@ -255,10 +281,10 @@ def do_evaluate(
(c.obj for c in columns),
)
)
elif self.name == pl_expr.BooleanFunction.IsIn:
elif self.name == BooleanFunctionName.IsIn:
needles, haystack = columns
return Column(plc.search.contains(haystack.obj, needles.obj))
elif self.name == pl_expr.BooleanFunction.Not:
elif self.name == BooleanFunctionName.Not:
(column,) = columns
return Column(
plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT)
Expand Down
84 changes: 68 additions & 16 deletions python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@

from __future__ import annotations

from enum import Enum, auto
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa

from polars.polars import _expr_nodes as pl_expr

import pylibcudf as plc

from cudf_polars.containers import Column
Expand All @@ -25,28 +24,81 @@
__all__ = ["TemporalFunction"]


class TemporalFunctionName(Enum):
    """
    Internal names for temporal functions.

    Used in place of the polars expression-node type so that expression
    nodes only hold objects defined in this package. Values are assigned
    by ``auto()`` in declaration order, so the order of members below is
    significant.
    """

    BaseUtcOffset = auto()
    CastTimeUnit = auto()
    Century = auto()
    Combine = auto()
    ConvertTimeZone = auto()
    DSTOffset = auto()
    Date = auto()
    Datetime = auto()
    DatetimeFunction = auto()
    Day = auto()
    Duration = auto()
    Hour = auto()
    IsLeapYear = auto()
    IsoYear = auto()
    Microsecond = auto()
    Millennium = auto()
    Millisecond = auto()
    Minute = auto()
    Month = auto()
    MonthEnd = auto()
    MonthStart = auto()
    Nanosecond = auto()
    OffsetBy = auto()
    OrdinalDay = auto()
    Quarter = auto()
    ReplaceTimeZone = auto()
    Round = auto()
    Second = auto()
    Time = auto()
    TimeStamp = auto()
    ToString = auto()
    TotalDays = auto()
    TotalHours = auto()
    TotalMicroseconds = auto()
    TotalMilliseconds = auto()
    TotalMinutes = auto()
    TotalNanoseconds = auto()
    TotalSeconds = auto()
    Truncate = auto()
    Week = auto()
    WeekDay = auto()
    WithTimeUnit = auto()
    Year = auto()

    @staticmethod
    def get_polars_type(tp: object) -> "TemporalFunctionName":
        """
        Convert a polars temporal-function value to the matching member.

        Despite its name, this returns a ``TemporalFunctionName`` member,
        not a polars object.

        Parameters
        ----------
        tp
            Object whose ``str()`` has the form
            ``"TemporalFunction.<Name>"``.

        Returns
        -------
        The ``TemporalFunctionName`` member called ``<Name>``.

        Raises
        ------
        ValueError
            If ``str(tp)`` is not exactly two dot-separated parts whose
            first part is ``"TemporalFunction"``.
        AttributeError
            If ``<Name>`` is not a member of ``TemporalFunctionName``.
        """
        parts = str(tp).split(".")
        # Validate the shape up front so malformed input produces the
        # documented message instead of a tuple-unpacking error.
        if len(parts) != 2 or parts[0] != "TemporalFunction":
            raise ValueError("TemporalFunction required")
        name = parts[1]
        # Look up members only: a plain getattr on the class would also
        # hand back methods and dunder attributes such as ``__members__``.
        try:
            return TemporalFunctionName.__members__[name]
        except KeyError:
            raise AttributeError(
                f"TemporalFunctionName has no member {name!r}"
            ) from None


class TemporalFunction(Expr):
__slots__ = ("name", "options")
_COMPONENT_MAP: ClassVar[
dict[pl_expr.TemporalFunction, plc.datetime.DatetimeComponent]
dict[TemporalFunctionName, plc.datetime.DatetimeComponent]
] = {
pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR,
pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH,
pl_expr.TemporalFunction.Day: plc.datetime.DatetimeComponent.DAY,
pl_expr.TemporalFunction.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
pl_expr.TemporalFunction.Hour: plc.datetime.DatetimeComponent.HOUR,
pl_expr.TemporalFunction.Minute: plc.datetime.DatetimeComponent.MINUTE,
pl_expr.TemporalFunction.Second: plc.datetime.DatetimeComponent.SECOND,
pl_expr.TemporalFunction.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
pl_expr.TemporalFunction.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
pl_expr.TemporalFunction.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
TemporalFunctionName.Year: plc.datetime.DatetimeComponent.YEAR,
TemporalFunctionName.Month: plc.datetime.DatetimeComponent.MONTH,
TemporalFunctionName.Day: plc.datetime.DatetimeComponent.DAY,
TemporalFunctionName.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
TemporalFunctionName.Hour: plc.datetime.DatetimeComponent.HOUR,
TemporalFunctionName.Minute: plc.datetime.DatetimeComponent.MINUTE,
TemporalFunctionName.Second: plc.datetime.DatetimeComponent.SECOND,
TemporalFunctionName.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
TemporalFunctionName.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
TemporalFunctionName.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
}
_non_child = ("dtype", "name", "options")

def __init__(
self,
dtype: plc.DataType,
name: pl_expr.TemporalFunction,
name: TemporalFunctionName,
options: tuple[Any, ...],
*children: Expr,
) -> None:
Expand All @@ -70,7 +122,7 @@ def do_evaluate(
for child in self.children
]
(column,) = columns
if self.name == pl_expr.TemporalFunction.Microsecond:
if self.name == TemporalFunctionName.Microsecond:
millis = plc.datetime.extract_datetime_component(
column.obj, plc.datetime.DatetimeComponent.MILLISECOND
)
Expand All @@ -90,7 +142,7 @@ def do_evaluate(
plc.types.DataType(plc.types.TypeId.INT32),
)
return Column(total_micros)
elif self.name == pl_expr.TemporalFunction.Nanosecond:
elif self.name == TemporalFunctionName.Nanosecond:
millis = plc.datetime.extract_datetime_component(
column.obj, plc.datetime.DatetimeComponent.MILLISECOND
)
Expand Down
Loading
Loading