Skip to content

Commit

Permalink
Move function-name enums into nested `Name` classes and use IntEnum
Browse files: browse the repository at this point in the history
  • Loading branch information
pentschev committed Nov 22, 2024
1 parent be7fa52 commit 3f69d98
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 215 deletions.
92 changes: 47 additions & 45 deletions python/cudf_polars/cudf_polars/dsl/expressions/boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from enum import Enum, auto
from enum import IntEnum, auto
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, ClassVar

Expand All @@ -33,49 +33,51 @@
__all__ = ["BooleanFunction"]


class BooleanFunctionName(Enum):
All = auto()
AllHorizontal = auto()
Any = auto()
AnyHorizontal = auto()
IsBetween = auto()
IsDuplicated = auto()
IsFinite = auto()
IsFirstDistinct = auto()
IsIn = auto()
IsInfinite = auto()
IsLastDistinct = auto()
IsNan = auto()
IsNotNan = auto()
IsNotNull = auto()
IsNull = auto()
IsUnique = auto()
Not = auto()
class BooleanFunction(Expr):
class Name(IntEnum):
"""Internal and picklable representation of polars' `BooleanFunction`."""

@classmethod
def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
function, name = str(obj).split(".", maxsplit=1)
if function != "BooleanFunction":
raise ValueError("BooleanFunction required")
return getattr(cls, name)
All = auto()
AllHorizontal = auto()
Any = auto()
AnyHorizontal = auto()
IsBetween = auto()
IsDuplicated = auto()
IsFinite = auto()
IsFirstDistinct = auto()
IsIn = auto()
IsInfinite = auto()
IsLastDistinct = auto()
IsNan = auto()
IsNotNan = auto()
IsNotNull = auto()
IsNull = auto()
IsUnique = auto()
Not = auto()

@classmethod
def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self:
"""Convert from polars' `BooleanFunction`."""
function, name = str(obj).split(".", maxsplit=1)
if function != "BooleanFunction":
raise ValueError("BooleanFunction required")
return getattr(cls, name)

class BooleanFunction(Expr):
__slots__ = ("name", "options")
_non_child = ("dtype", "name", "options")

def __init__(
self,
dtype: plc.DataType,
name: BooleanFunctionName,
name: BooleanFunction.Name,
options: tuple[Any, ...],
*children: Expr,
) -> None:
self.dtype = dtype
self.options = options
self.name = name
self.children = children
if self.name == BooleanFunctionName.IsIn and not all(
if self.name == BooleanFunction.Name.IsIn and not all(
c.dtype == self.children[0].dtype for c in self.children
):
# TODO: If polars IR doesn't put the casts in, we need to
Expand Down Expand Up @@ -139,12 +141,12 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
if self.name in (
BooleanFunctionName.IsFinite,
BooleanFunctionName.IsInfinite,
BooleanFunction.Name.IsFinite,
BooleanFunction.Name.IsInfinite,
):
# Avoid evaluating the child if the dtype tells us it's unnecessary.
(child,) = self.children
is_finite = self.name == BooleanFunctionName.IsFinite
is_finite = self.name == BooleanFunction.Name.IsFinite
if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64):
value = plc.interop.from_arrow(
pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype))
Expand All @@ -171,10 +173,10 @@ def do_evaluate(
]
# Kleene logic for Any (OR) and All (AND) if ignore_nulls is
# False
if self.name in (BooleanFunctionName.Any, BooleanFunctionName.All):
if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All):
(ignore_nulls,) = self.options
(column,) = columns
is_any = self.name == BooleanFunctionName.Any
is_any = self.name == BooleanFunction.Name.Any
agg = plc.aggregation.any() if is_any else plc.aggregation.all()
result = plc.reduce.reduce(column.obj, agg, self.dtype)
if not ignore_nulls and column.obj.null_count() > 0:
Expand All @@ -194,27 +196,27 @@ def do_evaluate(
# False || Null => Null; True && Null => Null
return Column(plc.Column.all_null_like(column.obj, 1))
return Column(plc.Column.from_scalar(result, 1))
if self.name == BooleanFunctionName.IsNull:
if self.name == BooleanFunction.Name.IsNull:
(column,) = columns
return Column(plc.unary.is_null(column.obj))
elif self.name == BooleanFunctionName.IsNotNull:
elif self.name == BooleanFunction.Name.IsNotNull:
(column,) = columns
return Column(plc.unary.is_valid(column.obj))
elif self.name == BooleanFunctionName.IsNan:
elif self.name == BooleanFunction.Name.IsNan:
(column,) = columns
return Column(
plc.unary.is_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
)
)
elif self.name == BooleanFunctionName.IsNotNan:
elif self.name == BooleanFunction.Name.IsNotNan:
(column,) = columns
return Column(
plc.unary.is_not_nan(column.obj).with_mask(
column.obj.null_mask(), column.obj.null_count()
)
)
elif self.name == BooleanFunctionName.IsFirstDistinct:
elif self.name == BooleanFunction.Name.IsFirstDistinct:
(column,) = columns
return self._distinct(
column,
Expand All @@ -226,7 +228,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == BooleanFunctionName.IsLastDistinct:
elif self.name == BooleanFunction.Name.IsLastDistinct:
(column,) = columns
return self._distinct(
column,
Expand All @@ -238,7 +240,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == BooleanFunctionName.IsUnique:
elif self.name == BooleanFunction.Name.IsUnique:
(column,) = columns
return self._distinct(
column,
Expand All @@ -250,7 +252,7 @@ def do_evaluate(
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == BooleanFunctionName.IsDuplicated:
elif self.name == BooleanFunction.Name.IsDuplicated:
(column,) = columns
return self._distinct(
column,
Expand All @@ -262,7 +264,7 @@ def do_evaluate(
pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype))
),
)
elif self.name == BooleanFunctionName.AllHorizontal:
elif self.name == BooleanFunction.Name.AllHorizontal:
return Column(
reduce(
partial(
Expand All @@ -273,7 +275,7 @@ def do_evaluate(
(c.obj for c in columns),
)
)
elif self.name == BooleanFunctionName.AnyHorizontal:
elif self.name == BooleanFunction.Name.AnyHorizontal:
return Column(
reduce(
partial(
Expand All @@ -284,10 +286,10 @@ def do_evaluate(
(c.obj for c in columns),
)
)
elif self.name == BooleanFunctionName.IsIn:
elif self.name == BooleanFunction.Name.IsIn:
needles, haystack = columns
return Column(plc.search.contains(haystack.obj, needles.obj))
elif self.name == BooleanFunctionName.Not:
elif self.name == BooleanFunction.Name.Not:
(column,) = columns
return Column(
plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT)
Expand Down
142 changes: 71 additions & 71 deletions python/cudf_polars/cudf_polars/dsl/expressions/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from enum import Enum, auto
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa
Expand All @@ -28,81 +28,81 @@
__all__ = ["TemporalFunction"]


class TemporalFunctionName(Enum):
BaseUtcOffset = auto()
CastTimeUnit = auto()
Century = auto()
Combine = auto()
ConvertTimeZone = auto()
DSTOffset = auto()
Date = auto()
Datetime = auto()
DatetimeFunction = auto()
Day = auto()
Duration = auto()
Hour = auto()
IsLeapYear = auto()
IsoYear = auto()
Microsecond = auto()
Millennium = auto()
Millisecond = auto()
Minute = auto()
Month = auto()
MonthEnd = auto()
MonthStart = auto()
Nanosecond = auto()
OffsetBy = auto()
OrdinalDay = auto()
Quarter = auto()
ReplaceTimeZone = auto()
Round = auto()
Second = auto()
Time = auto()
TimeStamp = auto()
ToString = auto()
TotalDays = auto()
TotalHours = auto()
TotalMicroseconds = auto()
TotalMilliseconds = auto()
TotalMinutes = auto()
TotalNanoseconds = auto()
TotalSeconds = auto()
Truncate = auto()
Week = auto()
WeekDay = auto()
WithTimeUnit = auto()
Year = auto()

@classmethod
def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self:
function, name = str(obj).split(".", maxsplit=1)
if function != "TemporalFunction":
raise ValueError("TemporalFunction required")
return getattr(cls, name)


class TemporalFunction(Expr):
class Name(IntEnum):
"""Internal and picklable representation of polars' `TemporalFunction`."""

BaseUtcOffset = auto()
CastTimeUnit = auto()
Century = auto()
Combine = auto()
ConvertTimeZone = auto()
DSTOffset = auto()
Date = auto()
Datetime = auto()
DatetimeFunction = auto()
Day = auto()
Duration = auto()
Hour = auto()
IsLeapYear = auto()
IsoYear = auto()
Microsecond = auto()
Millennium = auto()
Millisecond = auto()
Minute = auto()
Month = auto()
MonthEnd = auto()
MonthStart = auto()
Nanosecond = auto()
OffsetBy = auto()
OrdinalDay = auto()
Quarter = auto()
ReplaceTimeZone = auto()
Round = auto()
Second = auto()
Time = auto()
TimeStamp = auto()
ToString = auto()
TotalDays = auto()
TotalHours = auto()
TotalMicroseconds = auto()
TotalMilliseconds = auto()
TotalMinutes = auto()
TotalNanoseconds = auto()
TotalSeconds = auto()
Truncate = auto()
Week = auto()
WeekDay = auto()
WithTimeUnit = auto()
Year = auto()

@classmethod
def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self:
"""Convert from polars' `TemporalFunction`."""
function, name = str(obj).split(".", maxsplit=1)
if function != "TemporalFunction":
raise ValueError("TemporalFunction required")
return getattr(cls, name)

__slots__ = ("name", "options")
_COMPONENT_MAP: ClassVar[
dict[TemporalFunctionName, plc.datetime.DatetimeComponent]
] = {
TemporalFunctionName.Year: plc.datetime.DatetimeComponent.YEAR,
TemporalFunctionName.Month: plc.datetime.DatetimeComponent.MONTH,
TemporalFunctionName.Day: plc.datetime.DatetimeComponent.DAY,
TemporalFunctionName.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
TemporalFunctionName.Hour: plc.datetime.DatetimeComponent.HOUR,
TemporalFunctionName.Minute: plc.datetime.DatetimeComponent.MINUTE,
TemporalFunctionName.Second: plc.datetime.DatetimeComponent.SECOND,
TemporalFunctionName.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
TemporalFunctionName.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
TemporalFunctionName.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
}
_non_child = ("dtype", "name", "options")
_COMPONENT_MAP: ClassVar[dict[Name, plc.datetime.DatetimeComponent]] = {
Name.Year: plc.datetime.DatetimeComponent.YEAR,
Name.Month: plc.datetime.DatetimeComponent.MONTH,
Name.Day: plc.datetime.DatetimeComponent.DAY,
Name.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY,
Name.Hour: plc.datetime.DatetimeComponent.HOUR,
Name.Minute: plc.datetime.DatetimeComponent.MINUTE,
Name.Second: plc.datetime.DatetimeComponent.SECOND,
Name.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND,
Name.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND,
Name.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND,
}

def __init__(
self,
dtype: plc.DataType,
name: TemporalFunctionName,
name: TemporalFunction.Name,
options: tuple[Any, ...],
*children: Expr,
) -> None:
Expand All @@ -126,7 +126,7 @@ def do_evaluate(
for child in self.children
]
(column,) = columns
if self.name == TemporalFunctionName.Microsecond:
if self.name == TemporalFunction.Name.Microsecond:
millis = plc.datetime.extract_datetime_component(
column.obj, plc.datetime.DatetimeComponent.MILLISECOND
)
Expand All @@ -146,7 +146,7 @@ def do_evaluate(
plc.types.DataType(plc.types.TypeId.INT32),
)
return Column(total_micros)
elif self.name == TemporalFunctionName.Nanosecond:
elif self.name == TemporalFunction.Name.Nanosecond:
millis = plc.datetime.extract_datetime_component(
column.obj, plc.datetime.DatetimeComponent.MILLISECOND
)
Expand Down
Loading

0 comments on commit 3f69d98

Please sign in to comment.