diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index ac8169d516c..9cccd4d97d7 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -6,7 +6,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import IntEnum, auto from functools import partial, reduce from typing import TYPE_CHECKING, Any, ClassVar @@ -33,41 +33,43 @@ __all__ = ["BooleanFunction"] -class BooleanFunctionName(Enum): - All = auto() - AllHorizontal = auto() - Any = auto() - AnyHorizontal = auto() - IsBetween = auto() - IsDuplicated = auto() - IsFinite = auto() - IsFirstDistinct = auto() - IsIn = auto() - IsInfinite = auto() - IsLastDistinct = auto() - IsNan = auto() - IsNotNan = auto() - IsNotNull = auto() - IsNull = auto() - IsUnique = auto() - Not = auto() +class BooleanFunction(Expr): + class Name(IntEnum): + """Internal and picklable representation of polars' `BooleanFunction`.""" - @classmethod - def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self: - function, name = str(obj).split(".", maxsplit=1) - if function != "BooleanFunction": - raise ValueError("BooleanFunction required") - return getattr(cls, name) + All = auto() + AllHorizontal = auto() + Any = auto() + AnyHorizontal = auto() + IsBetween = auto() + IsDuplicated = auto() + IsFinite = auto() + IsFirstDistinct = auto() + IsIn = auto() + IsInfinite = auto() + IsLastDistinct = auto() + IsNan = auto() + IsNotNan = auto() + IsNotNull = auto() + IsNull = auto() + IsUnique = auto() + Not = auto() + @classmethod + def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self: + """Convert from polars' `BooleanFunction`.""" + function, name = str(obj).split(".", maxsplit=1) + if function != "BooleanFunction": + raise ValueError("BooleanFunction required") + return getattr(cls, name) -class BooleanFunction(Expr): __slots__ = ("name", "options") _non_child = ("dtype", "name", "options") def __init__( self, dtype: plc.DataType, - name: BooleanFunctionName, + name: BooleanFunction.Name, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -75,7 +77,7 @@ def __init__( self.options = options self.name = name self.children = children - if self.name == BooleanFunctionName.IsIn and not all( + if self.name == BooleanFunction.Name.IsIn and not all( c.dtype == self.children[0].dtype for c in self.children ): # TODO: If polars IR doesn't put the casts in, we need to @@ -139,12 +141,12 @@ def do_evaluate( ) -> Column: """Evaluate this expression given a dataframe for context.""" if self.name in ( - BooleanFunctionName.IsFinite, - BooleanFunctionName.IsInfinite, + BooleanFunction.Name.IsFinite, + BooleanFunction.Name.IsInfinite, ): # Avoid evaluating the child if the dtype tells us it's unnecessary. (child,) = self.children - is_finite = self.name == BooleanFunctionName.IsFinite + is_finite = self.name == BooleanFunction.Name.IsFinite if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64): value = plc.interop.from_arrow( pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype)) @@ -171,10 +173,10 @@ def do_evaluate( ] # Kleene logic for Any (OR) and All (AND) if ignore_nulls is # False - if self.name in (BooleanFunctionName.Any, BooleanFunctionName.All): + if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All): (ignore_nulls,) = self.options (column,) = columns - is_any = self.name == BooleanFunctionName.Any + is_any = self.name == BooleanFunction.Name.Any agg = plc.aggregation.any() if is_any else plc.aggregation.all() result = plc.reduce.reduce(column.obj, agg, self.dtype) if not ignore_nulls and column.obj.null_count() > 0: @@ -194,27 +196,27 @@ def do_evaluate( # False || Null => Null True && Null => Null return Column(plc.Column.all_null_like(column.obj, 1)) return Column(plc.Column.from_scalar(result, 1)) - if self.name == BooleanFunctionName.IsNull: + if self.name == BooleanFunction.Name.IsNull: (column,) = columns return Column(plc.unary.is_null(column.obj)) - elif self.name == BooleanFunctionName.IsNotNull: + elif self.name == BooleanFunction.Name.IsNotNull: (column,) = columns return Column(plc.unary.is_valid(column.obj)) - elif self.name == BooleanFunctionName.IsNan: + elif self.name == BooleanFunction.Name.IsNan: (column,) = columns return Column( plc.unary.is_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == BooleanFunctionName.IsNotNan: + elif self.name == BooleanFunction.Name.IsNotNan: (column,) = columns return Column( plc.unary.is_not_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == BooleanFunctionName.IsFirstDistinct: + elif self.name == BooleanFunction.Name.IsFirstDistinct: (column,) = columns return self._distinct( column, @@ -226,7 +228,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.IsLastDistinct: + elif self.name == BooleanFunction.Name.IsLastDistinct: (column,) = columns return self._distinct( column, @@ -238,7 +240,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.IsUnique: + elif self.name == BooleanFunction.Name.IsUnique: (column,) = columns return self._distinct( column, @@ -250,7 +252,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.IsDuplicated: + elif self.name == BooleanFunction.Name.IsDuplicated: (column,) = columns return self._distinct( column, @@ -262,7 +264,7 @@ def do_evaluate( pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.AllHorizontal: + elif self.name == BooleanFunction.Name.AllHorizontal: return Column( reduce( partial( @@ -273,7 +275,7 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == BooleanFunctionName.AnyHorizontal: + elif self.name == BooleanFunction.Name.AnyHorizontal: return Column( reduce( partial( @@ -284,10 +286,10 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == BooleanFunctionName.IsIn: + elif self.name == BooleanFunction.Name.IsIn: needles, haystack = columns return Column(plc.search.contains(haystack.obj, needles.obj)) - elif self.name == BooleanFunctionName.Not: + elif self.name == BooleanFunction.Name.Not: (column,) = columns return Column( plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index 7ee53f52306..d37318eb71b 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -6,7 +6,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import IntEnum, auto from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa @@ -28,81 +28,81 @@ __all__ = ["TemporalFunction"] -class TemporalFunctionName(Enum): - BaseUtcOffset = auto() - CastTimeUnit = auto() - Century = auto() - Combine = auto() - ConvertTimeZone = auto() - DSTOffset = auto() - Date = auto() - Datetime = auto() - DatetimeFunction = auto() - Day = auto() - Duration = auto() - Hour = auto() - IsLeapYear = auto() - IsoYear = auto() - Microsecond = auto() - Millennium = auto() - Millisecond = auto() - Minute = auto() - Month = auto() - MonthEnd = auto() - MonthStart = auto() - Nanosecond = auto() - OffsetBy = auto() - OrdinalDay = auto() - Quarter = auto() - ReplaceTimeZone = auto() - Round = auto() - Second = auto() - Time = auto() - TimeStamp = auto() - ToString = auto() - TotalDays = auto() - TotalHours = auto() - TotalMicroseconds = auto() - TotalMilliseconds = auto() - TotalMinutes = auto() - TotalNanoseconds = auto() - TotalSeconds = auto() - Truncate = auto() - Week = auto() - WeekDay = auto() - WithTimeUnit = auto() - Year = auto() - - @classmethod - def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: - function, name = str(obj).split(".", maxsplit=1) - if function != "TemporalFunction": - raise ValueError("TemporalFunction required") - return getattr(cls, name) - - class TemporalFunction(Expr): + class Name(IntEnum): + """Internal and picklable representation of polars' `TemporalFunction`.""" + + BaseUtcOffset = auto() + CastTimeUnit = auto() + Century = auto() + Combine = auto() + ConvertTimeZone = auto() + DSTOffset = auto() + Date = auto() + Datetime = auto() + DatetimeFunction = auto() + Day = auto() + Duration = auto() + Hour = auto() + IsLeapYear = auto() + IsoYear = auto() + Microsecond = auto() + Millennium = auto() + Millisecond = auto() + Minute = auto() + Month = auto() + MonthEnd = auto() + MonthStart = auto() + Nanosecond = auto() + OffsetBy = auto() + OrdinalDay = auto() + Quarter = auto() + ReplaceTimeZone = auto() + Round = auto() + Second = auto() + Time = auto() + TimeStamp = auto() + ToString = auto() + TotalDays = auto() + TotalHours = auto() + TotalMicroseconds = auto() + TotalMilliseconds = auto() + TotalMinutes = auto() + TotalNanoseconds = auto() + TotalSeconds = auto() + Truncate = auto() + Week = auto() + WeekDay = auto() + WithTimeUnit = auto() + Year = auto() + + @classmethod + def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: + """Convert from polars' `TemporalFunction`.""" + function, name = str(obj).split(".", maxsplit=1) + if function != "TemporalFunction": + raise ValueError("TemporalFunction required") + return getattr(cls, name) + __slots__ = ("name", "options") - _COMPONENT_MAP: ClassVar[ - dict[TemporalFunctionName, plc.datetime.DatetimeComponent] - ] = { - TemporalFunctionName.Year: plc.datetime.DatetimeComponent.YEAR, - TemporalFunctionName.Month: plc.datetime.DatetimeComponent.MONTH, - TemporalFunctionName.Day: plc.datetime.DatetimeComponent.DAY, - TemporalFunctionName.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY, - TemporalFunctionName.Hour: plc.datetime.DatetimeComponent.HOUR, - TemporalFunctionName.Minute: plc.datetime.DatetimeComponent.MINUTE, - TemporalFunctionName.Second: plc.datetime.DatetimeComponent.SECOND, - TemporalFunctionName.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND, - TemporalFunctionName.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND, - TemporalFunctionName.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND, - } _non_child = ("dtype", "name", "options") + _COMPONENT_MAP: ClassVar[dict[Name, plc.datetime.DatetimeComponent]] = { + Name.Year: plc.datetime.DatetimeComponent.YEAR, + Name.Month: plc.datetime.DatetimeComponent.MONTH, + Name.Day: plc.datetime.DatetimeComponent.DAY, + Name.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY, + Name.Hour: plc.datetime.DatetimeComponent.HOUR, + Name.Minute: plc.datetime.DatetimeComponent.MINUTE, + Name.Second: plc.datetime.DatetimeComponent.SECOND, + Name.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND, + Name.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND, + Name.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND, + } def __init__( self, dtype: plc.DataType, - name: TemporalFunctionName, + name: TemporalFunction.Name, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -126,7 +126,7 @@ def do_evaluate( for child in self.children ] (column,) = columns - if self.name == TemporalFunctionName.Microsecond: + if self.name == TemporalFunction.Name.Microsecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) @@ -146,7 +146,7 @@ def do_evaluate( plc.types.DataType(plc.types.TypeId.INT32), ) return Column(total_micros) - elif self.name == TemporalFunctionName.Nanosecond: + elif self.name == TemporalFunction.Name.Nanosecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 4368fc40bd3..ebf7dfee18b 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -6,7 +6,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import IntEnum, auto from typing import TYPE_CHECKING, Any import pyarrow as pa @@ -32,67 +32,69 @@ __all__ = ["StringFunction"] -class StringFunctionName(Enum): - Base64Decode = auto() - Base64Encode = auto() - ConcatHorizontal = auto() - ConcatVertical = auto() - Contains = auto() - ContainsMany = auto() - CountMatches = auto() - EndsWith = auto() - EscapeRegex = auto() - Extract = auto() - ExtractAll = auto() - ExtractGroups = auto() - Find = auto() - Head = auto() - HexDecode = auto() - HexEncode = auto() - JsonDecode = auto() - JsonPathMatch = auto() - LenBytes = auto() - LenChars = auto() - Lowercase = auto() - PadEnd = auto() - PadStart = auto() - Replace = auto() - ReplaceMany = auto() - Reverse = auto() - Slice = auto() - Split = auto() - SplitExact = auto() - SplitN = auto() - StartsWith = auto() - StripChars = auto() - StripCharsEnd = auto() - StripCharsStart = auto() - StripPrefix = auto() - StripSuffix = auto() - Strptime = auto() - Tail = auto() - Titlecase = auto() - ToDecimal = auto() - ToInteger = auto() - Uppercase = auto() - ZFill = auto() +class StringFunction(Expr): + class Name(IntEnum): + """Internal and picklable representation of polars' `StringFunction`.""" - @classmethod - def from_polars(cls, obj: pl_expr.StringFunction) -> Self: - function, name = str(obj).split(".", maxsplit=1) - if function != "StringFunction": - raise ValueError("StringFunction required") - return getattr(cls, name) + Base64Decode = auto() + Base64Encode = auto() + ConcatHorizontal = auto() + ConcatVertical = auto() + Contains = auto() + ContainsMany = auto() + CountMatches = auto() + EndsWith = auto() + EscapeRegex = auto() + Extract = auto() + ExtractAll = auto() + ExtractGroups = auto() + Find = auto() + Head = auto() + HexDecode = auto() + HexEncode = auto() + JsonDecode = auto() + JsonPathMatch = auto() + LenBytes = auto() + LenChars = auto() + Lowercase = auto() + PadEnd = auto() + PadStart = auto() + Replace = auto() + ReplaceMany = auto() + Reverse = auto() + Slice = auto() + Split = auto() + SplitExact = auto() + SplitN = auto() + StartsWith = auto() + StripChars = auto() + StripCharsEnd = auto() + StripCharsStart = auto() + StripPrefix = auto() + StripSuffix = auto() + Strptime = auto() + Tail = auto() + Titlecase = auto() + ToDecimal = auto() + ToInteger = auto() + Uppercase = auto() + ZFill = auto() + @classmethod + def from_polars(cls, obj: pl_expr.StringFunction) -> Self: + """Convert from polars' `StringFunction`.""" + function, name = str(obj).split(".", maxsplit=1) + if function != "StringFunction": + raise ValueError("StringFunction required") + return getattr(cls, name) -class StringFunction(Expr): __slots__ = ("name", "options", "_regex_program") _non_child = ("dtype", "name", "options") def __init__( self, dtype: plc.DataType, - name: StringFunctionName, + name: StringFunction.Name, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -104,21 +106,21 @@ def __init__( def _validate_input(self): if self.name not in ( - StringFunctionName.Contains, - StringFunctionName.EndsWith, - StringFunctionName.Lowercase, - StringFunctionName.Replace, - StringFunctionName.ReplaceMany, - StringFunctionName.Slice, - StringFunctionName.Strptime, - StringFunctionName.StartsWith, - StringFunctionName.StripChars, - StringFunctionName.StripCharsStart, - StringFunctionName.StripCharsEnd, - StringFunctionName.Uppercase, + StringFunction.Name.Contains, + StringFunction.Name.EndsWith, + StringFunction.Name.Lowercase, + StringFunction.Name.Replace, + StringFunction.Name.ReplaceMany, + StringFunction.Name.Slice, + StringFunction.Name.Strptime, + StringFunction.Name.StartsWith, + StringFunction.Name.StripChars, + StringFunction.Name.StripCharsStart, + StringFunction.Name.StripCharsEnd, + StringFunction.Name.Uppercase, ): raise NotImplementedError(f"String function {self.name}") - if self.name == StringFunctionName.Contains: + if self.name == StringFunction.Name.Contains: literal, strict = self.options if not literal: if not strict: @@ -139,7 +141,7 @@ def _validate_input(self): raise NotImplementedError( f"Unsupported regex {pattern} for GPU engine." ) from e - elif self.name == StringFunctionName.Replace: + elif self.name == StringFunction.Name.Replace: _, literal = self.options if not literal: raise NotImplementedError("literal=False is not supported for replace") @@ -150,7 +152,7 @@ def _validate_input(self): raise NotImplementedError( "libcudf replace does not support empty strings" ) - elif self.name == StringFunctionName.ReplaceMany: + elif self.name == StringFunction.Name.ReplaceMany: (ascii_case_insensitive,) = self.options if ascii_case_insensitive: raise NotImplementedError( @@ -166,12 +168,12 @@ def _validate_input(self): "libcudf replace_many is implemented differently from polars " "for empty strings" ) - elif self.name == StringFunctionName.Slice: + elif self.name == StringFunction.Name.Slice: if not all(isinstance(child, Literal) for child in self.children[1:]): raise NotImplementedError( "Slice only supports literal start and stop values" ) - elif self.name == StringFunctionName.Strptime: + elif self.name == StringFunction.Name.Strptime: format, _, exact, cache = self.options if cache: raise NotImplementedError("Strptime cache is a CPU feature") @@ -180,9 +182,9 @@ def _validate_input(self): if not exact: raise NotImplementedError("Strptime does not support exact=False") elif self.name in { - StringFunctionName.StripChars, - StringFunctionName.StripCharsStart, - StringFunctionName.StripCharsEnd, + StringFunction.Name.StripChars, + StringFunction.Name.StripCharsStart, + StringFunction.Name.StripCharsEnd, }: if not isinstance(self.children[1], Literal): raise NotImplementedError( @@ -197,7 +199,7 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - if self.name == StringFunctionName.Contains: + if self.name == StringFunction.Name.Contains: child, arg = self.children column = child.evaluate(df, context=context, mapping=mapping) @@ -214,7 +216,7 @@ def do_evaluate( return Column( plc.strings.contains.contains_re(column.obj, self._regex_program) ) - elif self.name == StringFunctionName.Slice: + elif self.name == StringFunction.Name.Slice: child, expr_offset, expr_length = self.children assert isinstance(expr_offset, Literal) assert isinstance(expr_length, Literal) @@ -245,16 +247,16 @@ def do_evaluate( ) ) elif self.name in { - StringFunctionName.StripChars, - StringFunctionName.StripCharsStart, - StringFunctionName.StripCharsEnd, + StringFunction.Name.StripChars, + StringFunction.Name.StripCharsStart, + StringFunction.Name.StripCharsEnd, }: column, chars = ( c.evaluate(df, context=context, mapping=mapping) for c in self.children ) - if self.name == StringFunctionName.StripCharsStart: + if self.name == StringFunction.Name.StripCharsStart: side = plc.strings.SideType.LEFT - elif self.name == StringFunctionName.StripCharsEnd: + elif self.name == StringFunction.Name.StripCharsEnd: side = plc.strings.SideType.RIGHT else: side = plc.strings.SideType.BOTH @@ -264,13 +266,13 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == StringFunctionName.Lowercase: + if self.name == StringFunction.Name.Lowercase: (column,) = columns return Column(plc.strings.case.to_lower(column.obj)) - elif self.name == StringFunctionName.Uppercase: + elif self.name == StringFunction.Name.Uppercase: (column,) = columns return Column(plc.strings.case.to_upper(column.obj)) - elif self.name == StringFunctionName.EndsWith: + elif self.name == StringFunction.Name.EndsWith: column, suffix = columns return Column( plc.strings.find.ends_with( @@ -280,7 +282,7 @@ def do_evaluate( else suffix.obj, ) ) - elif self.name == StringFunctionName.StartsWith: + elif self.name == StringFunction.Name.StartsWith: column, prefix = columns return Column( plc.strings.find.starts_with( @@ -290,7 +292,7 @@ def do_evaluate( else prefix.obj, ) ) - elif self.name == StringFunctionName.Strptime: + elif self.name == StringFunction.Name.Strptime: # TODO: ignores ambiguous format, strict, exact, cache = self.options col = self.children[0].evaluate(df, context=context, mapping=mapping) @@ -322,7 +324,7 @@ def do_evaluate( res.columns()[0], self.dtype, format ) ) - elif self.name == StringFunctionName.Replace: + elif self.name == StringFunction.Name.Replace: column, target, repl = columns n, _ = self.options return Column( @@ -330,7 +332,7 @@ def do_evaluate( column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n ) ) - elif self.name == StringFunctionName.ReplaceMany: + elif self.name == StringFunction.Name.ReplaceMany: column, target, repl = columns return Column( plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj) diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index 4c9313dfecc..0ec54a52d43 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -12,7 +12,6 @@ from pylibcudf import expressions as plc_expr from cudf_polars.dsl import expr -from cudf_polars.dsl.expressions.boolean import BooleanFunctionName from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged from cudf_polars.typing import GenericTransformer @@ -184,7 +183,7 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression: @_to_ast.register def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: - if node.name == BooleanFunctionName.IsIn: + if node.name == expr.BooleanFunction.Name.IsIn: needles, haystack = node.children if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: # 16 is an arbitrary limit @@ -203,14 +202,14 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: raise NotImplementedError( f"Parquet filters don't support {node.name} on columns" ) - if node.name == BooleanFunctionName.IsNull: + if node.name == expr.BooleanFunction.Name.IsNull: return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])) - elif node.name == BooleanFunctionName.IsNotNull: + elif node.name == expr.BooleanFunction.Name.IsNotNull: return plc_expr.Operation( plc_expr.ASTOperator.NOT, plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])), ) - elif node.name == BooleanFunctionName.Not: + elif node.name == expr.BooleanFunction.Name.Not: return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0])) raise NotImplementedError(f"AST conversion does not support {node.name}") diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 76b1f92fc0b..2727019ff7f 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -21,9 +21,6 @@ import pylibcudf as plc from cudf_polars.dsl import expr, ir -from cudf_polars.dsl.expressions.boolean import BooleanFunctionName -from cudf_polars.dsl.expressions.datetime import TemporalFunctionName -from cudf_polars.dsl.expressions.string import StringFunctionName from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes, sorting @@ -535,11 +532,15 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex pa.scalar("", type=plc.interop.to_arrow(column.dtype)), ) return expr.StringFunction( - dtype, StringFunctionName.from_polars(name), options, column, chars + dtype, + expr.StringFunction.Name.from_polars(name), + options, + column, + chars, ) return expr.StringFunction( dtype, - StringFunctionName.from_polars(name), + expr.StringFunction.Name.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) @@ -556,7 +557,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex ) return expr.BooleanFunction( dtype, - BooleanFunctionName.from_polars(name), + expr.BooleanFunction.Name.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) @@ -576,7 +577,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex } result_expr = expr.TemporalFunction( dtype, - TemporalFunctionName.from_polars(name), + expr.TemporalFunction.Name.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), )