From 6ea021cae81ebd30a28d30162de37ce7ca148c68 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 22 Nov 2024 02:52:33 -0800 Subject: [PATCH 01/11] Implement picklable `BooleanFunctionName` type --- .../cudf_polars/dsl/expressions/boolean.py | 68 +++++++++++++------ python/cudf_polars/cudf_polars/dsl/to_ast.py | 11 ++- .../cudf_polars/cudf_polars/dsl/translate.py | 3 +- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index 8db8172ebd1..294f1a86a63 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -6,13 +6,12 @@ from __future__ import annotations +from enum import Enum, auto from functools import partial, reduce from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa -from polars.polars import _expr_nodes as pl_expr - import pylibcudf as plc from cudf_polars.containers import Column @@ -31,6 +30,33 @@ __all__ = ["BooleanFunction"] +class BooleanFunctionName(Enum): + All = auto() + AllHorizontal = auto() + Any = auto() + AnyHorizontal = auto() + IsBetween = auto() + IsDuplicated = auto() + IsFinite = auto() + IsFirstDistinct = auto() + IsIn = auto() + IsInfinite = auto() + IsLastDistinct = auto() + IsNan = auto() + IsNotNan = auto() + IsNotNull = auto() + IsNull = auto() + IsUnique = auto() + Not = auto() + + @staticmethod + def get_polars_type(tp: BooleanFunctionName): + function, name = str(tp).split(".") + if function != "BooleanFunction": + raise ValueError("BooleanFunction required") + return getattr(BooleanFunctionName, name) + + class BooleanFunction(Expr): __slots__ = ("name", "options") _non_child = ("dtype", "name", "options") @@ -38,7 +64,7 @@ class BooleanFunction(Expr): def __init__( self, dtype: plc.DataType, - name: pl_expr.BooleanFunction, + name: BooleanFunctionName, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -46,7 +72,7 @@ def __init__( self.options = options self.name = name self.children = children - if self.name == pl_expr.BooleanFunction.IsIn and not all( + if self.name == BooleanFunctionName.IsIn and not all( c.dtype == self.children[0].dtype for c in self.children ): # TODO: If polars IR doesn't put the casts in, we need to @@ -110,12 +136,12 @@ def do_evaluate( ) -> Column: """Evaluate this expression given a dataframe for context.""" if self.name in ( - pl_expr.BooleanFunction.IsFinite, - pl_expr.BooleanFunction.IsInfinite, + BooleanFunctionName.IsFinite, + BooleanFunctionName.IsInfinite, ): # Avoid evaluating the child if the dtype tells us it's unnecessary. (child,) = self.children - is_finite = self.name == pl_expr.BooleanFunction.IsFinite + is_finite = self.name == BooleanFunctionName.IsFinite if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64): value = plc.interop.from_arrow( pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype)) @@ -142,10 +168,10 @@ def do_evaluate( ] # Kleene logic for Any (OR) and All (AND) if ignore_nulls is # False - if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All): + if self.name in (BooleanFunctionName.Any, BooleanFunctionName.All): (ignore_nulls,) = self.options (column,) = columns - is_any = self.name == pl_expr.BooleanFunction.Any + is_any = self.name == BooleanFunctionName.Any agg = plc.aggregation.any() if is_any else plc.aggregation.all() result = plc.reduce.reduce(column.obj, agg, self.dtype) if not ignore_nulls and column.obj.null_count() > 0: @@ -165,27 +191,27 @@ def do_evaluate( # False || Null => Null True && Null => Null return Column(plc.Column.all_null_like(column.obj, 1)) return Column(plc.Column.from_scalar(result, 1)) - if self.name == pl_expr.BooleanFunction.IsNull: + if self.name == BooleanFunctionName.IsNull: (column,) = columns return Column(plc.unary.is_null(column.obj)) - elif self.name == pl_expr.BooleanFunction.IsNotNull: + elif self.name == BooleanFunctionName.IsNotNull: (column,) = columns return Column(plc.unary.is_valid(column.obj)) - elif self.name == pl_expr.BooleanFunction.IsNan: + elif self.name == BooleanFunctionName.IsNan: (column,) = columns return Column( plc.unary.is_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == pl_expr.BooleanFunction.IsNotNan: + elif self.name == BooleanFunctionName.IsNotNan: (column,) = columns return Column( plc.unary.is_not_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == pl_expr.BooleanFunction.IsFirstDistinct: + elif self.name == BooleanFunctionName.IsFirstDistinct: (column,) = columns return self._distinct( column, @@ -197,7 +223,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == pl_expr.BooleanFunction.IsLastDistinct: + elif self.name == BooleanFunctionName.IsLastDistinct: (column,) = columns return self._distinct( column, @@ -209,7 +235,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == pl_expr.BooleanFunction.IsUnique: + elif self.name == BooleanFunctionName.IsUnique: (column,) = columns return self._distinct( column, @@ -221,7 +247,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == pl_expr.BooleanFunction.IsDuplicated: + elif self.name == BooleanFunctionName.IsDuplicated: (column,) = columns return self._distinct( column, @@ -233,7 +259,7 @@ def do_evaluate( pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == pl_expr.BooleanFunction.AllHorizontal: + elif self.name == BooleanFunctionName.AllHorizontal: return Column( reduce( partial( @@ -244,7 +270,7 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == pl_expr.BooleanFunction.AnyHorizontal: + elif self.name == BooleanFunctionName.AnyHorizontal: return Column( reduce( partial( @@ -255,10 +281,10 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == pl_expr.BooleanFunction.IsIn: + elif self.name == BooleanFunctionName.IsIn: needles, haystack = columns return Column(plc.search.contains(haystack.obj, needles.obj)) - elif self.name == pl_expr.BooleanFunction.Not: + elif self.name == BooleanFunctionName.Not: (column,) = columns return Column( plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT) diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index acc4b3669af..4c9313dfecc 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -8,12 +8,11 @@ from functools import partial, reduce, singledispatch from typing import TYPE_CHECKING, TypeAlias -from polars.polars import _expr_nodes as pl_expr - import pylibcudf as plc from pylibcudf import expressions as plc_expr from cudf_polars.dsl import expr +from cudf_polars.dsl.expressions.boolean import BooleanFunctionName from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged from cudf_polars.typing import GenericTransformer @@ -185,7 +184,7 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression: @_to_ast.register def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: - if node.name == pl_expr.BooleanFunction.IsIn: + if node.name == BooleanFunctionName.IsIn: needles, haystack = node.children if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: # 16 is an arbitrary limit @@ -204,14 +203,14 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: raise NotImplementedError( f"Parquet filters don't support {node.name} on columns" ) - if node.name == pl_expr.BooleanFunction.IsNull: + if node.name == BooleanFunctionName.IsNull: return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])) - elif node.name == pl_expr.BooleanFunction.IsNotNull: + elif node.name == BooleanFunctionName.IsNotNull: return plc_expr.Operation( plc_expr.ASTOperator.NOT, plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])), ) - elif node.name == pl_expr.BooleanFunction.Not: + elif node.name == BooleanFunctionName.Not: return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0])) raise NotImplementedError(f"AST conversion does not support {node.name}") diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 12fc2a196cd..9c88465ac47 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -21,6 +21,7 @@ import pylibcudf as plc from cudf_polars.dsl import expr, ir +from cudf_polars.dsl.expressions.boolean import BooleanFunctionName from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes, sorting @@ -551,7 +552,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex ) return expr.BooleanFunction( dtype, - name, + BooleanFunctionName.get_polars_type(name), options, *(translator.translate_expr(n=n) for n in node.input), ) From 9f2e5c89b6487612afae3e793fd83f589eefd4ce Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 22 Nov 2024 03:23:42 -0800 Subject: [PATCH 02/11] Implement picklable `StringFunctionName` type --- .../cudf_polars/dsl/expressions/string.py | 125 +++++++++++++----- .../cudf_polars/cudf_polars/dsl/translate.py | 7 +- 2 files changed, 94 insertions(+), 38 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 8b66c9d4676..d773d5bcf40 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -6,13 +6,13 @@ from __future__ import annotations +from enum import Enum, auto from typing import TYPE_CHECKING, Any import pyarrow as pa import pyarrow.compute as pc from polars.exceptions import InvalidOperationError -from polars.polars import _expr_nodes as pl_expr import pylibcudf as plc @@ -28,6 +28,59 @@ __all__ = ["StringFunction"] +class StringFunctionName(Enum): + Base64Decode = auto() + Base64Encode = auto() + ConcatHorizontal = auto() + ConcatVertical = auto() + Contains = auto() + ContainsMany = auto() + CountMatches = auto() + EndsWith = auto() + EscapeRegex = auto() + Extract = auto() + ExtractAll = auto() + ExtractGroups = auto() + Find = auto() + Head = auto() + HexDecode = auto() + HexEncode = auto() + JsonDecode = auto() + JsonPathMatch = auto() + LenBytes = auto() + LenChars = auto() + Lowercase = auto() + PadEnd = auto() + PadStart = auto() + Replace = auto() + ReplaceMany = auto() + Reverse = auto() + Slice = auto() + Split = auto() + SplitExact = auto() + SplitN = auto() + StartsWith = auto() + StripChars = auto() + StripCharsEnd = auto() + StripCharsStart = auto() + StripPrefix = auto() + StripSuffix = auto() + Strptime = auto() + Tail = auto() + Titlecase = auto() + ToDecimal = auto() + ToInteger = auto() + Uppercase = auto() + ZFill = auto() + + @staticmethod + def get_polars_type(tp: StringFunctionName): + function, name = str(tp).split(".") + if function != "StringFunction": + raise ValueError("StringFunction required") + return getattr(StringFunctionName, name) + + class StringFunction(Expr): __slots__ = ("name", "options", "_regex_program") _non_child = ("dtype", "name", "options") @@ -35,7 +88,7 @@ class StringFunction(Expr): def __init__( self, dtype: plc.DataType, - name: pl_expr.StringFunction, + name: StringFunctionName, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -47,21 +100,21 @@ def __init__( def _validate_input(self): if self.name not in ( - pl_expr.StringFunction.Contains, - pl_expr.StringFunction.EndsWith, - pl_expr.StringFunction.Lowercase, - pl_expr.StringFunction.Replace, - pl_expr.StringFunction.ReplaceMany, - pl_expr.StringFunction.Slice, - pl_expr.StringFunction.Strptime, - pl_expr.StringFunction.StartsWith, - pl_expr.StringFunction.StripChars, - pl_expr.StringFunction.StripCharsStart, - pl_expr.StringFunction.StripCharsEnd, - pl_expr.StringFunction.Uppercase, + StringFunctionName.Contains, + StringFunctionName.EndsWith, + StringFunctionName.Lowercase, + StringFunctionName.Replace, + StringFunctionName.ReplaceMany, + StringFunctionName.Slice, + StringFunctionName.Strptime, + StringFunctionName.StartsWith, + StringFunctionName.StripChars, + StringFunctionName.StripCharsStart, + StringFunctionName.StripCharsEnd, + StringFunctionName.Uppercase, ): raise NotImplementedError(f"String function {self.name}") - if self.name == pl_expr.StringFunction.Contains: + if self.name == StringFunctionName.Contains: literal, strict = self.options if not literal: if not strict: @@ -82,7 +135,7 @@ def _validate_input(self): raise NotImplementedError( f"Unsupported regex {pattern} for GPU engine." ) from e - elif self.name == pl_expr.StringFunction.Replace: + elif self.name == StringFunctionName.Replace: _, literal = self.options if not literal: raise NotImplementedError("literal=False is not supported for replace") @@ -93,7 +146,7 @@ def _validate_input(self): raise NotImplementedError( "libcudf replace does not support empty strings" ) - elif self.name == pl_expr.StringFunction.ReplaceMany: + elif self.name == StringFunctionName.ReplaceMany: (ascii_case_insensitive,) = self.options if ascii_case_insensitive: raise NotImplementedError( @@ -109,12 +162,12 @@ def _validate_input(self): "libcudf replace_many is implemented differently from polars " "for empty strings" ) - elif self.name == pl_expr.StringFunction.Slice: + elif self.name == StringFunctionName.Slice: if not all(isinstance(child, Literal) for child in self.children[1:]): raise NotImplementedError( "Slice only supports literal start and stop values" ) - elif self.name == pl_expr.StringFunction.Strptime: + elif self.name == StringFunctionName.Strptime: format, _, exact, cache = self.options if cache: raise NotImplementedError("Strptime cache is a CPU feature") @@ -123,9 +176,9 @@ def _validate_input(self): if not exact: raise NotImplementedError("Strptime does not support exact=False") elif self.name in { - pl_expr.StringFunction.StripChars, - pl_expr.StringFunction.StripCharsStart, - pl_expr.StringFunction.StripCharsEnd, + StringFunctionName.StripChars, + StringFunctionName.StripCharsStart, + StringFunctionName.StripCharsEnd, }: if not isinstance(self.children[1], Literal): raise NotImplementedError( @@ -140,7 +193,7 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - if self.name == pl_expr.StringFunction.Contains: + if self.name == StringFunctionName.Contains: child, arg = self.children column = child.evaluate(df, context=context, mapping=mapping) @@ -157,7 +210,7 @@ def do_evaluate( return Column( plc.strings.contains.contains_re(column.obj, self._regex_program) ) - elif self.name == pl_expr.StringFunction.Slice: + elif self.name == StringFunctionName.Slice: child, expr_offset, expr_length = self.children assert isinstance(expr_offset, Literal) assert isinstance(expr_length, Literal) @@ -188,16 +241,16 @@ def do_evaluate( ) ) elif self.name in { - pl_expr.StringFunction.StripChars, - pl_expr.StringFunction.StripCharsStart, - pl_expr.StringFunction.StripCharsEnd, + StringFunctionName.StripChars, + StringFunctionName.StripCharsStart, + StringFunctionName.StripCharsEnd, }: column, chars = ( c.evaluate(df, context=context, mapping=mapping) for c in self.children ) - if self.name == pl_expr.StringFunction.StripCharsStart: + if self.name == StringFunctionName.StripCharsStart: side = plc.strings.SideType.LEFT - elif self.name == pl_expr.StringFunction.StripCharsEnd: + elif self.name == StringFunctionName.StripCharsEnd: side = plc.strings.SideType.RIGHT else: side = plc.strings.SideType.BOTH @@ -207,13 +260,13 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == pl_expr.StringFunction.Lowercase: + if self.name == StringFunctionName.Lowercase: (column,) = columns return Column(plc.strings.case.to_lower(column.obj)) - elif self.name == pl_expr.StringFunction.Uppercase: + elif self.name == StringFunctionName.Uppercase: (column,) = columns return Column(plc.strings.case.to_upper(column.obj)) - elif self.name == pl_expr.StringFunction.EndsWith: + elif self.name == StringFunctionName.EndsWith: column, suffix = columns return Column( plc.strings.find.ends_with( @@ -223,7 +276,7 @@ def do_evaluate( else suffix.obj, ) ) - elif self.name == pl_expr.StringFunction.StartsWith: + elif self.name == StringFunctionName.StartsWith: column, prefix = columns return Column( plc.strings.find.starts_with( @@ -233,7 +286,7 @@ def do_evaluate( else prefix.obj, ) ) - elif self.name == pl_expr.StringFunction.Strptime: + elif self.name == StringFunctionName.Strptime: # TODO: ignores ambiguous format, strict, exact, cache = self.options col = self.children[0].evaluate(df, context=context, mapping=mapping) @@ -265,7 +318,7 @@ def do_evaluate( res.columns()[0], self.dtype, format ) ) - elif self.name == pl_expr.StringFunction.Replace: + elif self.name == StringFunctionName.Replace: column, target, repl = columns n, _ = self.options return Column( @@ -273,7 +326,7 @@ def do_evaluate( column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n ) ) - elif self.name == pl_expr.StringFunction.ReplaceMany: + elif self.name == StringFunctionName.ReplaceMany: column, target, repl = columns return Column( plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 9c88465ac47..ceed2bde2a8 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -22,6 +22,7 @@ from cudf_polars.dsl import expr, ir from cudf_polars.dsl.expressions.boolean import BooleanFunctionName +from cudf_polars.dsl.expressions.string import StringFunctionName from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes, sorting @@ -532,10 +533,12 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex column.dtype, pa.scalar("", type=plc.interop.to_arrow(column.dtype)), ) - return expr.StringFunction(dtype, name, options, column, chars) + return expr.StringFunction( + dtype, StringFunctionName.get_polars_type(name), options, column, chars + ) return expr.StringFunction( dtype, - name, + StringFunctionName.get_polars_type(name), options, *(translator.translate_expr(n=n) for n in node.input), ) From d142f50128dd0a484f50078406dfd25ae013fd63 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 22 Nov 2024 03:25:08 -0800 Subject: [PATCH 03/11] Implement picklable `TemporalFunctionName` type --- .../cudf_polars/dsl/expressions/datetime.py | 84 +++++++++++++++---- .../cudf_polars/cudf_polars/dsl/translate.py | 3 +- 2 files changed, 70 insertions(+), 17 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index cd8e5c6a4eb..d917c484e5c 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -6,12 +6,11 @@ from __future__ import annotations +from enum import Enum, auto from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa -from polars.polars import _expr_nodes as pl_expr - import pylibcudf as plc from cudf_polars.containers import Column @@ -25,28 +24,81 @@ __all__ = ["TemporalFunction"] +class TemporalFunctionName(Enum): + BaseUtcOffset = auto() + CastTimeUnit = auto() + Century = auto() + Combine = auto() + ConvertTimeZone = auto() + DSTOffset = auto() + Date = auto() + Datetime = auto() + DatetimeFunction = auto() + Day = auto() + Duration = auto() + Hour = auto() + IsLeapYear = auto() + IsoYear = auto() + Microsecond = auto() + Millennium = auto() + Millisecond = auto() + Minute = auto() + Month = auto() + MonthEnd = auto() + MonthStart = auto() + Nanosecond = auto() + OffsetBy = auto() + OrdinalDay = auto() + Quarter = auto() + ReplaceTimeZone = auto() + Round = auto() + Second = auto() + Time = auto() + TimeStamp = auto() + ToString = auto() + TotalDays = auto() + TotalHours = auto() + TotalMicroseconds = auto() + TotalMilliseconds = auto() + TotalMinutes = auto() + TotalNanoseconds = auto() + TotalSeconds = auto() + Truncate = auto() + Week = auto() + WeekDay = auto() + WithTimeUnit = auto() + Year = auto() + + @staticmethod + def get_polars_type(tp: TemporalFunctionName): + function, name = str(tp).split(".") + if function != "TemporalFunction": + raise ValueError("TemporalFunction required") + return getattr(TemporalFunctionName, name) + + class TemporalFunction(Expr): __slots__ = ("name", "options") _COMPONENT_MAP: ClassVar[ - dict[pl_expr.TemporalFunction, plc.datetime.DatetimeComponent] + dict[TemporalFunctionName, plc.datetime.DatetimeComponent] ] = { - pl_expr.TemporalFunction.Year: plc.datetime.DatetimeComponent.YEAR, - pl_expr.TemporalFunction.Month: plc.datetime.DatetimeComponent.MONTH, - pl_expr.TemporalFunction.Day: plc.datetime.DatetimeComponent.DAY, - pl_expr.TemporalFunction.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY, - pl_expr.TemporalFunction.Hour: plc.datetime.DatetimeComponent.HOUR, - pl_expr.TemporalFunction.Minute: plc.datetime.DatetimeComponent.MINUTE, - pl_expr.TemporalFunction.Second: plc.datetime.DatetimeComponent.SECOND, - pl_expr.TemporalFunction.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND, - pl_expr.TemporalFunction.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND, - pl_expr.TemporalFunction.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND, + TemporalFunctionName.Year: plc.datetime.DatetimeComponent.YEAR, + TemporalFunctionName.Month: plc.datetime.DatetimeComponent.MONTH, + TemporalFunctionName.Day: plc.datetime.DatetimeComponent.DAY, + TemporalFunctionName.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY, + TemporalFunctionName.Hour: plc.datetime.DatetimeComponent.HOUR, + TemporalFunctionName.Minute: plc.datetime.DatetimeComponent.MINUTE, + TemporalFunctionName.Second: plc.datetime.DatetimeComponent.SECOND, + TemporalFunctionName.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND, + TemporalFunctionName.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND, + TemporalFunctionName.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND, } _non_child = ("dtype", "name", "options") def __init__( self, dtype: plc.DataType, - name: pl_expr.TemporalFunction, + name: TemporalFunctionName, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -70,7 +122,7 @@ def do_evaluate( for child in self.children ] (column,) = columns - if self.name == pl_expr.TemporalFunction.Microsecond: + if self.name == TemporalFunctionName.Microsecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) @@ -90,7 +142,7 @@ def do_evaluate( plc.types.DataType(plc.types.TypeId.INT32), ) return Column(total_micros) - elif self.name == pl_expr.TemporalFunction.Nanosecond: + elif self.name == TemporalFunctionName.Nanosecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index ceed2bde2a8..23e5c963abb 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -22,6 +22,7 @@ from cudf_polars.dsl import expr, ir from cudf_polars.dsl.expressions.boolean import BooleanFunctionName +from cudf_polars.dsl.expressions.datetime import TemporalFunctionName from cudf_polars.dsl.expressions.string import StringFunctionName from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser @@ -575,7 +576,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex } result_expr = expr.TemporalFunction( dtype, - name, + TemporalFunctionName.get_polars_type(name), options, *(translator.translate_expr(n=n) for n in node.input), ) From be7fa529ad07f88ff61086dcbd2c86903ec04033 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 22 Nov 2024 07:48:23 -0800 Subject: [PATCH 04/11] Rename method to `from_polars` and mark `@classmethod` --- .../cudf_polars/dsl/expressions/boolean.py | 11 +++++++---- .../cudf_polars/dsl/expressions/datetime.py | 12 ++++++++---- .../cudf_polars/dsl/expressions/string.py | 12 ++++++++---- python/cudf_polars/cudf_polars/dsl/translate.py | 8 ++++---- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index 294f1a86a63..ac8169d516c 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -23,7 +23,10 @@ if TYPE_CHECKING: from collections.abc import Mapping + from typing_extensions import Self + import polars.type_aliases as pl_types + from polars.polars import _expr_nodes as pl_expr from cudf_polars.containers import DataFrame @@ -49,12 +52,12 @@ class BooleanFunctionName(Enum): IsUnique = auto() Not = auto() - @staticmethod - def get_polars_type(tp: BooleanFunctionName): - function, name = str(tp).split(".") + @classmethod + def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self: + function, name = str(obj).split(".", maxsplit=1) if function != "BooleanFunction": raise ValueError("BooleanFunction required") - return getattr(BooleanFunctionName, name) + return getattr(cls, name) class BooleanFunction(Expr): diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index d917c484e5c..7ee53f52306 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -19,6 +19,10 @@ if TYPE_CHECKING: from collections.abc import Mapping + from typing_extensions import Self + + from polars.polars import _expr_nodes as pl_expr + from cudf_polars.containers import DataFrame __all__ = ["TemporalFunction"] @@ -69,12 +73,12 @@ class TemporalFunctionName(Enum): WithTimeUnit = auto() Year = auto() - @staticmethod - def get_polars_type(tp: TemporalFunctionName): - function, name = str(tp).split(".") + @classmethod + def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: + function, name = str(obj).split(".", maxsplit=1) if function != "TemporalFunction": raise ValueError("TemporalFunction required") - return getattr(TemporalFunctionName, name) + return getattr(cls, name) class TemporalFunction(Expr): diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index d773d5bcf40..4368fc40bd3 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -23,6 +23,10 @@ if TYPE_CHECKING: from collections.abc import Mapping + from typing_extensions import Self + + from polars.polars import _expr_nodes as pl_expr + from cudf_polars.containers import DataFrame __all__ = ["StringFunction"] @@ -73,12 +77,12 @@ class StringFunctionName(Enum): Uppercase = auto() ZFill = auto() - @staticmethod - def get_polars_type(tp: StringFunctionName): - function, name = str(tp).split(".") + @classmethod + def from_polars(cls, obj: pl_expr.StringFunction) -> Self: + function, name = str(obj).split(".", maxsplit=1) if function != "StringFunction": raise ValueError("StringFunction required") - return getattr(StringFunctionName, name) + return getattr(cls, name) class StringFunction(Expr): diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 23e5c963abb..76b1f92fc0b 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -535,11 +535,11 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex pa.scalar("", type=plc.interop.to_arrow(column.dtype)), ) return expr.StringFunction( - dtype, StringFunctionName.get_polars_type(name), options, column, chars + dtype, StringFunctionName.from_polars(name), options, column, chars ) return expr.StringFunction( dtype, - StringFunctionName.get_polars_type(name), + StringFunctionName.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) @@ -556,7 +556,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex ) return expr.BooleanFunction( dtype, - BooleanFunctionName.get_polars_type(name), + BooleanFunctionName.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) @@ -576,7 +576,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex } result_expr = expr.TemporalFunction( dtype, - TemporalFunctionName.get_polars_type(name), + TemporalFunctionName.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) From 3f69d98ed7e525aef6ac1456a5586a1ea12f28ff Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 22 Nov 2024 08:33:20 -0800 Subject: [PATCH 05/11] Move types to internal classes `Name` and use `IntEnum` --- .../cudf_polars/dsl/expressions/boolean.py | 92 ++++----- .../cudf_polars/dsl/expressions/datetime.py | 142 +++++++------- .../cudf_polars/dsl/expressions/string.py | 176 +++++++++--------- python/cudf_polars/cudf_polars/dsl/to_ast.py | 9 +- .../cudf_polars/cudf_polars/dsl/translate.py | 15 +- 5 files changed, 219 insertions(+), 215 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index ac8169d516c..9cccd4d97d7 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -6,7 +6,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import IntEnum, auto from functools import partial, reduce from typing import TYPE_CHECKING, Any, ClassVar @@ -33,41 +33,43 @@ __all__ = ["BooleanFunction"] -class BooleanFunctionName(Enum): - All = auto() - AllHorizontal = auto() - Any = auto() - AnyHorizontal = auto() - IsBetween = auto() - IsDuplicated = auto() - IsFinite = auto() - IsFirstDistinct = auto() - IsIn = auto() - IsInfinite = auto() - IsLastDistinct = auto() - IsNan = auto() - IsNotNan = auto() - IsNotNull = auto() - IsNull = auto() - IsUnique = auto() - Not = auto() +class BooleanFunction(Expr): + class Name(IntEnum): + """Internal and picklable representation of polars' `BooleanFunction`.""" - @classmethod - def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self: - function, name = str(obj).split(".", maxsplit=1) - if function != "BooleanFunction": - raise ValueError("BooleanFunction required") - return getattr(cls, name) + All = auto() + AllHorizontal = auto() + Any = auto() + AnyHorizontal = auto() + IsBetween = auto() + IsDuplicated = auto() + IsFinite = auto() + IsFirstDistinct = auto() + IsIn = auto() + IsInfinite = auto() + IsLastDistinct = auto() + IsNan = auto() + IsNotNan = auto() + IsNotNull = auto() + IsNull = auto() + IsUnique = auto() + Not = auto() + @classmethod + def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self: + """Convert from polars' `BooleanFunction`.""" + function, name = str(obj).split(".", maxsplit=1) + if function != "BooleanFunction": + raise ValueError("BooleanFunction required") + return getattr(cls, name) -class BooleanFunction(Expr): __slots__ = ("name", "options") _non_child = ("dtype", "name", "options") def __init__( self, dtype: plc.DataType, - name: BooleanFunctionName, + name: BooleanFunction.Name, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -75,7 +77,7 @@ def __init__( self.options = options self.name = name self.children = children - if self.name == BooleanFunctionName.IsIn and not all( + if self.name == BooleanFunction.Name.IsIn and not all( c.dtype == self.children[0].dtype for c in self.children ): # TODO: If polars IR doesn't put the casts in, we need to @@ -139,12 +141,12 @@ def do_evaluate( ) -> Column: """Evaluate this expression given a dataframe for context.""" if self.name in ( - BooleanFunctionName.IsFinite, - BooleanFunctionName.IsInfinite, + BooleanFunction.Name.IsFinite, + BooleanFunction.Name.IsInfinite, ): # Avoid evaluating the child if the dtype tells us it's unnecessary. (child,) = self.children - is_finite = self.name == BooleanFunctionName.IsFinite + is_finite = self.name == BooleanFunction.Name.IsFinite if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64): value = plc.interop.from_arrow( pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype)) @@ -171,10 +173,10 @@ def do_evaluate( ] # Kleene logic for Any (OR) and All (AND) if ignore_nulls is # False - if self.name in (BooleanFunctionName.Any, BooleanFunctionName.All): + if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All): (ignore_nulls,) = self.options (column,) = columns - is_any = self.name == BooleanFunctionName.Any + is_any = self.name == BooleanFunction.Name.Any agg = plc.aggregation.any() if is_any else plc.aggregation.all() result = plc.reduce.reduce(column.obj, agg, self.dtype) if not ignore_nulls and column.obj.null_count() > 0: @@ -194,27 +196,27 @@ def do_evaluate( # False || Null => Null True && Null => Null return Column(plc.Column.all_null_like(column.obj, 1)) return Column(plc.Column.from_scalar(result, 1)) - if self.name == BooleanFunctionName.IsNull: + if self.name == BooleanFunction.Name.IsNull: (column,) = columns return Column(plc.unary.is_null(column.obj)) - elif self.name == BooleanFunctionName.IsNotNull: + elif self.name == BooleanFunction.Name.IsNotNull: (column,) = columns return Column(plc.unary.is_valid(column.obj)) - elif self.name == BooleanFunctionName.IsNan: + elif self.name == BooleanFunction.Name.IsNan: (column,) = columns return Column( plc.unary.is_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == BooleanFunctionName.IsNotNan: + elif self.name == BooleanFunction.Name.IsNotNan: (column,) = columns return Column( plc.unary.is_not_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == BooleanFunctionName.IsFirstDistinct: + elif self.name == BooleanFunction.Name.IsFirstDistinct: (column,) = columns return self._distinct( column, @@ -226,7 +228,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.IsLastDistinct: + elif self.name == BooleanFunction.Name.IsLastDistinct: (column,) = columns return self._distinct( column, @@ -238,7 +240,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.IsUnique: + elif self.name == BooleanFunction.Name.IsUnique: (column,) = columns return self._distinct( column, @@ -250,7 +252,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.IsDuplicated: + elif self.name == BooleanFunction.Name.IsDuplicated: (column,) = columns return self._distinct( column, @@ -262,7 +264,7 @@ def do_evaluate( pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunctionName.AllHorizontal: + elif self.name == BooleanFunction.Name.AllHorizontal: return Column( reduce( partial( @@ -273,7 +275,7 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == BooleanFunctionName.AnyHorizontal: + elif self.name == BooleanFunction.Name.AnyHorizontal: return Column( reduce( partial( @@ -284,10 +286,10 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == BooleanFunctionName.IsIn: + elif self.name == BooleanFunction.Name.IsIn: needles, haystack = columns return Column(plc.search.contains(haystack.obj, needles.obj)) - elif self.name == BooleanFunctionName.Not: + elif self.name == BooleanFunction.Name.Not: (column,) = columns return Column( plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index 7ee53f52306..d37318eb71b 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -6,7 +6,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import IntEnum, auto from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa @@ -28,81 +28,81 @@ __all__ = ["TemporalFunction"] -class TemporalFunctionName(Enum): - BaseUtcOffset = auto() - CastTimeUnit = auto() - Century = auto() - Combine = auto() - ConvertTimeZone = auto() - DSTOffset = auto() - Date = auto() - Datetime = auto() - DatetimeFunction = auto() - Day = auto() - Duration = auto() - Hour = auto() - IsLeapYear = auto() - IsoYear = auto() - Microsecond = auto() - Millennium = auto() - Millisecond = auto() - Minute = auto() - Month = auto() - MonthEnd = auto() - MonthStart = auto() - Nanosecond = auto() - OffsetBy = auto() - OrdinalDay = auto() - Quarter = auto() - ReplaceTimeZone = auto() - Round = auto() - Second = auto() - Time = auto() - TimeStamp = auto() - ToString = auto() - TotalDays = auto() - TotalHours = auto() - TotalMicroseconds = auto() - TotalMilliseconds = auto() - TotalMinutes = auto() - TotalNanoseconds = auto() - TotalSeconds = auto() - Truncate = auto() - Week = auto() - WeekDay = auto() - WithTimeUnit = auto() - Year = auto() - - @classmethod - def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: - function, name = str(obj).split(".", maxsplit=1) - if function != "TemporalFunction": - raise ValueError("TemporalFunction required") - return getattr(cls, name) - - class TemporalFunction(Expr): + class Name(IntEnum): + """Internal and picklable representation of polars' `TemporalFunction`.""" + + BaseUtcOffset = auto() + CastTimeUnit = auto() + Century = auto() + Combine = auto() + ConvertTimeZone = auto() + DSTOffset = auto() + Date = auto() + Datetime = auto() + DatetimeFunction = auto() + Day = auto() + Duration = auto() + Hour = auto() + IsLeapYear = auto() + IsoYear = auto() + Microsecond = auto() + Millennium = auto() + Millisecond = auto() + Minute = auto() + Month = auto() + MonthEnd = auto() + MonthStart = auto() + Nanosecond = auto() + OffsetBy = auto() + OrdinalDay = auto() + Quarter = auto() + ReplaceTimeZone = auto() + Round = auto() + Second = auto() + Time = auto() + TimeStamp = auto() + ToString = auto() + TotalDays = auto() + TotalHours = auto() + TotalMicroseconds = auto() + TotalMilliseconds = auto() + TotalMinutes = auto() + TotalNanoseconds = auto() + TotalSeconds = auto() + Truncate = auto() + Week = auto() + WeekDay = auto() + WithTimeUnit = auto() + Year = auto() + + @classmethod + def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: + """Convert from polars' `TemporalFunction`.""" + function, name = str(obj).split(".", maxsplit=1) + if function != "TemporalFunction": + raise ValueError("TemporalFunction required") + return getattr(cls, name) + __slots__ = ("name", "options") - _COMPONENT_MAP: ClassVar[ - dict[TemporalFunctionName, plc.datetime.DatetimeComponent] - ] = { - TemporalFunctionName.Year: plc.datetime.DatetimeComponent.YEAR, - TemporalFunctionName.Month: plc.datetime.DatetimeComponent.MONTH, - TemporalFunctionName.Day: plc.datetime.DatetimeComponent.DAY, - TemporalFunctionName.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY, - TemporalFunctionName.Hour: plc.datetime.DatetimeComponent.HOUR, - TemporalFunctionName.Minute: plc.datetime.DatetimeComponent.MINUTE, - TemporalFunctionName.Second: plc.datetime.DatetimeComponent.SECOND, - TemporalFunctionName.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND, - TemporalFunctionName.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND, - TemporalFunctionName.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND, - } _non_child = ("dtype", "name", "options") + _COMPONENT_MAP: ClassVar[dict[Name, plc.datetime.DatetimeComponent]] = { + Name.Year: plc.datetime.DatetimeComponent.YEAR, + Name.Month: plc.datetime.DatetimeComponent.MONTH, + Name.Day: plc.datetime.DatetimeComponent.DAY, + Name.WeekDay: plc.datetime.DatetimeComponent.WEEKDAY, + Name.Hour: plc.datetime.DatetimeComponent.HOUR, + Name.Minute: plc.datetime.DatetimeComponent.MINUTE, + Name.Second: plc.datetime.DatetimeComponent.SECOND, + Name.Millisecond: plc.datetime.DatetimeComponent.MILLISECOND, + Name.Microsecond: plc.datetime.DatetimeComponent.MICROSECOND, + Name.Nanosecond: plc.datetime.DatetimeComponent.NANOSECOND, + } def __init__( self, dtype: plc.DataType, - name: TemporalFunctionName, + name: TemporalFunction.Name, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -126,7 +126,7 @@ def do_evaluate( for child in self.children ] (column,) = columns - if self.name == TemporalFunctionName.Microsecond: + if self.name == TemporalFunction.Name.Microsecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) @@ -146,7 +146,7 @@ def do_evaluate( plc.types.DataType(plc.types.TypeId.INT32), ) return Column(total_micros) - elif self.name == TemporalFunctionName.Nanosecond: + elif self.name == TemporalFunction.Name.Nanosecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 4368fc40bd3..ebf7dfee18b 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -6,7 +6,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import IntEnum, auto from typing import TYPE_CHECKING, Any import pyarrow as pa @@ -32,67 +32,69 @@ __all__ = ["StringFunction"] -class StringFunctionName(Enum): - Base64Decode = auto() - Base64Encode = auto() - ConcatHorizontal = auto() - ConcatVertical = auto() - Contains = auto() - ContainsMany = auto() - CountMatches = auto() - EndsWith = auto() - EscapeRegex = auto() - Extract = auto() - ExtractAll = auto() - ExtractGroups = auto() - Find = auto() - Head = auto() - HexDecode = auto() - HexEncode = auto() - JsonDecode = auto() - JsonPathMatch = auto() - LenBytes = auto() - LenChars = auto() - Lowercase = auto() - PadEnd = auto() - PadStart = auto() - Replace = auto() - ReplaceMany = auto() - Reverse = auto() - Slice = auto() - Split = auto() - SplitExact = auto() - SplitN = auto() - StartsWith = auto() - StripChars = auto() - StripCharsEnd = auto() - StripCharsStart = auto() - StripPrefix = auto() - StripSuffix = auto() - Strptime = auto() - Tail = auto() - Titlecase = auto() - ToDecimal = auto() - ToInteger = auto() - Uppercase = auto() - ZFill = auto() +class StringFunction(Expr): + class Name(IntEnum): + """Internal and picklable representation of polars' `StringFunction`.""" - @classmethod - def from_polars(cls, obj: pl_expr.StringFunction) -> Self: - function, name = str(obj).split(".", maxsplit=1) - if function != "StringFunction": - raise ValueError("StringFunction required") - return getattr(cls, name) + Base64Decode = auto() + Base64Encode = auto() + ConcatHorizontal = auto() + ConcatVertical = auto() + Contains = auto() + ContainsMany = auto() + CountMatches = auto() + EndsWith = auto() + EscapeRegex = auto() + Extract = auto() + ExtractAll = auto() + ExtractGroups = auto() + Find = auto() + Head = auto() + HexDecode = auto() + HexEncode = auto() + JsonDecode = auto() + JsonPathMatch = auto() + LenBytes = auto() + LenChars = auto() + Lowercase = auto() + PadEnd = auto() + PadStart = auto() + Replace = auto() + ReplaceMany = auto() + Reverse = auto() + Slice = auto() + Split = auto() + SplitExact = auto() + SplitN = auto() + StartsWith = auto() + StripChars = auto() + StripCharsEnd = auto() + StripCharsStart = auto() + StripPrefix = auto() + StripSuffix = auto() + Strptime = auto() + Tail = auto() + Titlecase = auto() + ToDecimal = auto() + ToInteger = auto() + Uppercase = auto() + ZFill = auto() + @classmethod + def from_polars(cls, obj: pl_expr.StringFunction) -> Self: + """Convert from polars' `StringFunction`.""" + function, name = str(obj).split(".", maxsplit=1) + if function != "StringFunction": + raise ValueError("StringFunction required") + return getattr(cls, name) -class StringFunction(Expr): __slots__ = ("name", "options", "_regex_program") _non_child = ("dtype", "name", "options") def __init__( self, dtype: plc.DataType, - name: StringFunctionName, + name: StringFunction.Name, options: tuple[Any, ...], *children: Expr, ) -> None: @@ -104,21 +106,21 @@ def __init__( def _validate_input(self): if self.name not in ( - StringFunctionName.Contains, - StringFunctionName.EndsWith, - StringFunctionName.Lowercase, - StringFunctionName.Replace, - StringFunctionName.ReplaceMany, - StringFunctionName.Slice, - StringFunctionName.Strptime, - StringFunctionName.StartsWith, - StringFunctionName.StripChars, - StringFunctionName.StripCharsStart, - StringFunctionName.StripCharsEnd, - StringFunctionName.Uppercase, + StringFunction.Name.Contains, + StringFunction.Name.EndsWith, + StringFunction.Name.Lowercase, + StringFunction.Name.Replace, + StringFunction.Name.ReplaceMany, + StringFunction.Name.Slice, + StringFunction.Name.Strptime, + StringFunction.Name.StartsWith, + StringFunction.Name.StripChars, + StringFunction.Name.StripCharsStart, + StringFunction.Name.StripCharsEnd, + StringFunction.Name.Uppercase, ): raise NotImplementedError(f"String function {self.name}") - if self.name == StringFunctionName.Contains: + if self.name == StringFunction.Name.Contains: literal, strict = self.options if not literal: if not strict: @@ -139,7 +141,7 @@ def _validate_input(self): raise NotImplementedError( f"Unsupported regex {pattern} for GPU engine." ) from e - elif self.name == StringFunctionName.Replace: + elif self.name == StringFunction.Name.Replace: _, literal = self.options if not literal: raise NotImplementedError("literal=False is not supported for replace") @@ -150,7 +152,7 @@ def _validate_input(self): raise NotImplementedError( "libcudf replace does not support empty strings" ) - elif self.name == StringFunctionName.ReplaceMany: + elif self.name == StringFunction.Name.ReplaceMany: (ascii_case_insensitive,) = self.options if ascii_case_insensitive: raise NotImplementedError( @@ -166,12 +168,12 @@ def _validate_input(self): "libcudf replace_many is implemented differently from polars " "for empty strings" ) - elif self.name == StringFunctionName.Slice: + elif self.name == StringFunction.Name.Slice: if not all(isinstance(child, Literal) for child in self.children[1:]): raise NotImplementedError( "Slice only supports literal start and stop values" ) - elif self.name == StringFunctionName.Strptime: + elif self.name == StringFunction.Name.Strptime: format, _, exact, cache = self.options if cache: raise NotImplementedError("Strptime cache is a CPU feature") @@ -180,9 +182,9 @@ def _validate_input(self): if not exact: raise NotImplementedError("Strptime does not support exact=False") elif self.name in { - StringFunctionName.StripChars, - StringFunctionName.StripCharsStart, - StringFunctionName.StripCharsEnd, + StringFunction.Name.StripChars, + StringFunction.Name.StripCharsStart, + StringFunction.Name.StripCharsEnd, }: if not isinstance(self.children[1], Literal): raise NotImplementedError( @@ -197,7 +199,7 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - if self.name == StringFunctionName.Contains: + if self.name == StringFunction.Name.Contains: child, arg = self.children column = child.evaluate(df, context=context, mapping=mapping) @@ -214,7 +216,7 @@ def do_evaluate( return Column( plc.strings.contains.contains_re(column.obj, self._regex_program) ) - elif self.name == StringFunctionName.Slice: + elif self.name == StringFunction.Name.Slice: child, expr_offset, expr_length = self.children assert isinstance(expr_offset, Literal) assert isinstance(expr_length, Literal) @@ -245,16 +247,16 @@ def do_evaluate( ) ) elif self.name in { - StringFunctionName.StripChars, - StringFunctionName.StripCharsStart, - StringFunctionName.StripCharsEnd, + StringFunction.Name.StripChars, + StringFunction.Name.StripCharsStart, + StringFunction.Name.StripCharsEnd, }: column, chars = ( c.evaluate(df, context=context, mapping=mapping) for c in self.children ) - if self.name == StringFunctionName.StripCharsStart: + if self.name == StringFunction.Name.StripCharsStart: side = plc.strings.SideType.LEFT - elif self.name == StringFunctionName.StripCharsEnd: + elif self.name == StringFunction.Name.StripCharsEnd: side = plc.strings.SideType.RIGHT else: side = plc.strings.SideType.BOTH @@ -264,13 +266,13 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == StringFunctionName.Lowercase: + if self.name == StringFunction.Name.Lowercase: (column,) = columns return Column(plc.strings.case.to_lower(column.obj)) - elif self.name == StringFunctionName.Uppercase: + elif self.name == StringFunction.Name.Uppercase: (column,) = columns return Column(plc.strings.case.to_upper(column.obj)) - elif self.name == StringFunctionName.EndsWith: + elif self.name == StringFunction.Name.EndsWith: column, suffix = columns return Column( plc.strings.find.ends_with( @@ -280,7 +282,7 @@ def do_evaluate( else suffix.obj, ) ) - elif self.name == StringFunctionName.StartsWith: + elif self.name == StringFunction.Name.StartsWith: column, prefix = columns return Column( plc.strings.find.starts_with( @@ -290,7 +292,7 @@ def do_evaluate( else prefix.obj, ) ) - elif self.name == StringFunctionName.Strptime: + elif self.name == StringFunction.Name.Strptime: # TODO: ignores ambiguous format, strict, exact, cache = self.options col = self.children[0].evaluate(df, context=context, mapping=mapping) @@ -322,7 +324,7 @@ def do_evaluate( res.columns()[0], self.dtype, format ) ) - elif self.name == StringFunctionName.Replace: + elif self.name == StringFunction.Name.Replace: column, target, repl = columns n, _ = self.options return Column( @@ -330,7 +332,7 @@ def do_evaluate( column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n ) ) - elif self.name == StringFunctionName.ReplaceMany: + elif self.name == StringFunction.Name.ReplaceMany: column, target, repl = columns return Column( plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj) diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index 4c9313dfecc..0ec54a52d43 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -12,7 +12,6 @@ from pylibcudf import expressions as plc_expr from cudf_polars.dsl import expr -from cudf_polars.dsl.expressions.boolean import BooleanFunctionName from cudf_polars.dsl.traversal import CachingVisitor, reuse_if_unchanged from cudf_polars.typing import GenericTransformer @@ -184,7 +183,7 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression: @_to_ast.register def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: - if node.name == BooleanFunctionName.IsIn: + if node.name == expr.BooleanFunction.Name.IsIn: needles, haystack = node.children if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: # 16 is an arbitrary limit @@ -203,14 +202,14 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: raise NotImplementedError( f"Parquet filters don't support {node.name} on columns" ) - if node.name == BooleanFunctionName.IsNull: + if node.name == expr.BooleanFunction.Name.IsNull: return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])) - elif node.name == BooleanFunctionName.IsNotNull: + elif node.name == expr.BooleanFunction.Name.IsNotNull: return plc_expr.Operation( plc_expr.ASTOperator.NOT, plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])), ) - elif node.name == BooleanFunctionName.Not: + elif node.name == expr.BooleanFunction.Name.Not: return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0])) raise NotImplementedError(f"AST conversion does not support {node.name}") diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 76b1f92fc0b..2727019ff7f 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -21,9 +21,6 @@ import pylibcudf as plc from cudf_polars.dsl import expr, ir -from cudf_polars.dsl.expressions.boolean import BooleanFunctionName -from cudf_polars.dsl.expressions.datetime import TemporalFunctionName -from cudf_polars.dsl.expressions.string import StringFunctionName from cudf_polars.dsl.to_ast import insert_colrefs from cudf_polars.typing import NodeTraverser from cudf_polars.utils import dtypes, sorting @@ -535,11 +532,15 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex pa.scalar("", type=plc.interop.to_arrow(column.dtype)), ) return expr.StringFunction( - dtype, StringFunctionName.from_polars(name), options, column, chars + dtype, + expr.StringFunction.Name.from_polars(name), + options, + column, + chars, ) return expr.StringFunction( dtype, - StringFunctionName.from_polars(name), + expr.StringFunction.Name.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) @@ -556,7 +557,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex ) return expr.BooleanFunction( dtype, - BooleanFunctionName.from_polars(name), + expr.BooleanFunction.Name.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) @@ -576,7 +577,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex } result_expr = expr.TemporalFunction( dtype, - TemporalFunctionName.from_polars(name), + expr.TemporalFunction.Name.from_polars(name), options, *(translator.translate_expr(n=n) for n in node.input), ) From 7d009f9578f554c51ddbac0b6120496956339042 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 22 Nov 2024 08:41:04 -0800 Subject: [PATCH 06/11] Use `is` for `IntEnum` equality checks --- .../cudf_polars/dsl/expressions/boolean.py | 30 ++++++++--------- .../cudf_polars/dsl/expressions/datetime.py | 4 +-- .../cudf_polars/dsl/expressions/string.py | 32 +++++++++---------- python/cudf_polars/cudf_polars/dsl/to_ast.py | 8 ++--- 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index 9cccd4d97d7..b3146c0448f 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -77,7 +77,7 @@ def __init__( self.options = options self.name = name self.children = children - if self.name == BooleanFunction.Name.IsIn and not all( + if self.name is BooleanFunction.Name.IsIn and not all( c.dtype == self.children[0].dtype for c in self.children ): # TODO: If polars IR doesn't put the casts in, we need to @@ -146,7 +146,7 @@ def do_evaluate( ): # Avoid evaluating the child if the dtype tells us it's unnecessary. (child,) = self.children - is_finite = self.name == BooleanFunction.Name.IsFinite + is_finite = self.name is BooleanFunction.Name.IsFinite if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64): value = plc.interop.from_arrow( pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype)) @@ -176,7 +176,7 @@ def do_evaluate( if self.name in (BooleanFunction.Name.Any, BooleanFunction.Name.All): (ignore_nulls,) = self.options (column,) = columns - is_any = self.name == BooleanFunction.Name.Any + is_any = self.name is BooleanFunction.Name.Any agg = plc.aggregation.any() if is_any else plc.aggregation.all() result = plc.reduce.reduce(column.obj, agg, self.dtype) if not ignore_nulls and column.obj.null_count() > 0: @@ -196,27 +196,27 @@ def do_evaluate( # False || Null => Null True && Null => Null return Column(plc.Column.all_null_like(column.obj, 1)) return Column(plc.Column.from_scalar(result, 1)) - if self.name == BooleanFunction.Name.IsNull: + if self.name is BooleanFunction.Name.IsNull: (column,) = columns return Column(plc.unary.is_null(column.obj)) - elif self.name == BooleanFunction.Name.IsNotNull: + elif self.name is BooleanFunction.Name.IsNotNull: (column,) = columns return Column(plc.unary.is_valid(column.obj)) - elif self.name == BooleanFunction.Name.IsNan: + elif self.name is BooleanFunction.Name.IsNan: (column,) = columns return Column( plc.unary.is_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == BooleanFunction.Name.IsNotNan: + elif self.name is BooleanFunction.Name.IsNotNan: (column,) = columns return Column( plc.unary.is_not_nan(column.obj).with_mask( column.obj.null_mask(), column.obj.null_count() ) ) - elif self.name == BooleanFunction.Name.IsFirstDistinct: + elif self.name is BooleanFunction.Name.IsFirstDistinct: (column,) = columns return self._distinct( column, @@ -228,7 +228,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunction.Name.IsLastDistinct: + elif self.name is BooleanFunction.Name.IsLastDistinct: (column,) = columns return self._distinct( column, @@ -240,7 +240,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunction.Name.IsUnique: + elif self.name is BooleanFunction.Name.IsUnique: (column,) = columns return self._distinct( column, @@ -252,7 +252,7 @@ def do_evaluate( pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunction.Name.IsDuplicated: + elif self.name is BooleanFunction.Name.IsDuplicated: (column,) = columns return self._distinct( column, @@ -264,7 +264,7 @@ def do_evaluate( pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype)) ), ) - elif self.name == BooleanFunction.Name.AllHorizontal: + elif self.name is BooleanFunction.Name.AllHorizontal: return Column( reduce( partial( @@ -275,7 +275,7 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == BooleanFunction.Name.AnyHorizontal: + elif self.name is BooleanFunction.Name.AnyHorizontal: return Column( reduce( partial( @@ -286,10 +286,10 @@ def do_evaluate( (c.obj for c in columns), ) ) - elif self.name == BooleanFunction.Name.IsIn: + elif self.name is BooleanFunction.Name.IsIn: needles, haystack = columns return Column(plc.search.contains(haystack.obj, needles.obj)) - elif self.name == BooleanFunction.Name.Not: + elif self.name is BooleanFunction.Name.Not: (column,) = columns return Column( plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index d37318eb71b..ae559065318 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -126,7 +126,7 @@ def do_evaluate( for child in self.children ] (column,) = columns - if self.name == TemporalFunction.Name.Microsecond: + if self.name is TemporalFunction.Name.Microsecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) @@ -146,7 +146,7 @@ def do_evaluate( plc.types.DataType(plc.types.TypeId.INT32), ) return Column(total_micros) - elif self.name == TemporalFunction.Name.Nanosecond: + elif self.name is TemporalFunction.Name.Nanosecond: millis = plc.datetime.extract_datetime_component( column.obj, plc.datetime.DatetimeComponent.MILLISECOND ) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index ebf7dfee18b..03a6ee3b0e0 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -120,7 +120,7 @@ def _validate_input(self): StringFunction.Name.Uppercase, ): raise NotImplementedError(f"String function {self.name}") - if self.name == StringFunction.Name.Contains: + if self.name is StringFunction.Name.Contains: literal, strict = self.options if not literal: if not strict: @@ -141,7 +141,7 @@ def _validate_input(self): raise NotImplementedError( f"Unsupported regex {pattern} for GPU engine." ) from e - elif self.name == StringFunction.Name.Replace: + elif self.name is StringFunction.Name.Replace: _, literal = self.options if not literal: raise NotImplementedError("literal=False is not supported for replace") @@ -152,7 +152,7 @@ def _validate_input(self): raise NotImplementedError( "libcudf replace does not support empty strings" ) - elif self.name == StringFunction.Name.ReplaceMany: + elif self.name is StringFunction.Name.ReplaceMany: (ascii_case_insensitive,) = self.options if ascii_case_insensitive: raise NotImplementedError( @@ -168,12 +168,12 @@ def _validate_input(self): "libcudf replace_many is implemented differently from polars " "for empty strings" ) - elif self.name == StringFunction.Name.Slice: + elif self.name is StringFunction.Name.Slice: if not all(isinstance(child, Literal) for child in self.children[1:]): raise NotImplementedError( "Slice only supports literal start and stop values" ) - elif self.name == StringFunction.Name.Strptime: + elif self.name is StringFunction.Name.Strptime: format, _, exact, cache = self.options if cache: raise NotImplementedError("Strptime cache is a CPU feature") @@ -199,7 +199,7 @@ def do_evaluate( mapping: Mapping[Expr, Column] | None = None, ) -> Column: """Evaluate this expression given a dataframe for context.""" - if self.name == StringFunction.Name.Contains: + if self.name is StringFunction.Name.Contains: child, arg = self.children column = child.evaluate(df, context=context, mapping=mapping) @@ -216,7 +216,7 @@ def do_evaluate( return Column( plc.strings.contains.contains_re(column.obj, self._regex_program) ) - elif self.name == StringFunction.Name.Slice: + elif self.name is StringFunction.Name.Slice: child, expr_offset, expr_length = self.children assert isinstance(expr_offset, Literal) assert isinstance(expr_length, Literal) @@ -254,9 +254,9 @@ def do_evaluate( column, chars = ( c.evaluate(df, context=context, mapping=mapping) for c in self.children ) - if self.name == StringFunction.Name.StripCharsStart: + if self.name is StringFunction.Name.StripCharsStart: side = plc.strings.SideType.LEFT - elif self.name == StringFunction.Name.StripCharsEnd: + elif self.name is StringFunction.Name.StripCharsEnd: side = plc.strings.SideType.RIGHT else: side = plc.strings.SideType.BOTH @@ -266,13 +266,13 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == StringFunction.Name.Lowercase: + if self.name is StringFunction.Name.Lowercase: (column,) = columns return Column(plc.strings.case.to_lower(column.obj)) - elif self.name == StringFunction.Name.Uppercase: + elif self.name is StringFunction.Name.Uppercase: (column,) = columns return Column(plc.strings.case.to_upper(column.obj)) - elif self.name == StringFunction.Name.EndsWith: + elif self.name is StringFunction.Name.EndsWith: column, suffix = columns return Column( plc.strings.find.ends_with( @@ -282,7 +282,7 @@ def do_evaluate( else suffix.obj, ) ) - elif self.name == StringFunction.Name.StartsWith: + elif self.name is StringFunction.Name.StartsWith: column, prefix = columns return Column( plc.strings.find.starts_with( @@ -292,7 +292,7 @@ def do_evaluate( else prefix.obj, ) ) - elif self.name == StringFunction.Name.Strptime: + elif self.name is StringFunction.Name.Strptime: # TODO: ignores ambiguous format, strict, exact, cache = self.options col = self.children[0].evaluate(df, context=context, mapping=mapping) @@ -324,7 +324,7 @@ def do_evaluate( res.columns()[0], self.dtype, format ) ) - elif self.name == StringFunction.Name.Replace: + elif self.name is StringFunction.Name.Replace: column, target, repl = columns n, _ = self.options return Column( @@ -332,7 +332,7 @@ def do_evaluate( column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n ) ) - elif self.name == StringFunction.Name.ReplaceMany: + elif self.name is StringFunction.Name.ReplaceMany: column, target, repl = columns return Column( plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj) diff --git a/python/cudf_polars/cudf_polars/dsl/to_ast.py b/python/cudf_polars/cudf_polars/dsl/to_ast.py index 0ec54a52d43..c3febc833e2 100644 --- a/python/cudf_polars/cudf_polars/dsl/to_ast.py +++ b/python/cudf_polars/cudf_polars/dsl/to_ast.py @@ -183,7 +183,7 @@ def _(node: expr.BinOp, self: Transformer) -> plc_expr.Expression: @_to_ast.register def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: - if node.name == expr.BooleanFunction.Name.IsIn: + if node.name is expr.BooleanFunction.Name.IsIn: needles, haystack = node.children if isinstance(haystack, expr.LiteralColumn) and len(haystack.value) < 16: # 16 is an arbitrary limit @@ -202,14 +202,14 @@ def _(node: expr.BooleanFunction, self: Transformer) -> plc_expr.Expression: raise NotImplementedError( f"Parquet filters don't support {node.name} on columns" ) - if node.name == expr.BooleanFunction.Name.IsNull: + if node.name is expr.BooleanFunction.Name.IsNull: return plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])) - elif node.name == expr.BooleanFunction.Name.IsNotNull: + elif node.name is expr.BooleanFunction.Name.IsNotNull: return plc_expr.Operation( plc_expr.ASTOperator.NOT, plc_expr.Operation(plc_expr.ASTOperator.IS_NULL, self(node.children[0])), ) - elif node.name == expr.BooleanFunction.Name.Not: + elif node.name is expr.BooleanFunction.Name.Not: return plc_expr.Operation(plc_expr.ASTOperator.NOT, self(node.children[0])) raise NotImplementedError(f"AST conversion does not support {node.name}") From 9b54437c927f326fcc4442876d63d0639af7f1e3 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 25 Nov 2024 01:15:59 -0800 Subject: [PATCH 07/11] Handle string unpacking failure --- python/cudf_polars/cudf_polars/dsl/expressions/boolean.py | 6 +++++- python/cudf_polars/cudf_polars/dsl/expressions/datetime.py | 6 +++++- python/cudf_polars/cudf_polars/dsl/expressions/string.py | 6 +++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py index b3146c0448f..1682e7a8a9c 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/boolean.py @@ -58,7 +58,11 @@ class Name(IntEnum): @classmethod def from_polars(cls, obj: pl_expr.BooleanFunction) -> Self: """Convert from polars' `BooleanFunction`.""" - function, name = str(obj).split(".", maxsplit=1) + try: + function, name = str(obj).split(".", maxsplit=1) + except ValueError: + # Failed to unpack string + function = None if function != "BooleanFunction": raise ValueError("BooleanFunction required") return getattr(cls, name) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py index ae559065318..c2dddfd9940 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/datetime.py @@ -79,7 +79,11 @@ class Name(IntEnum): @classmethod def from_polars(cls, obj: pl_expr.TemporalFunction) -> Self: """Convert from polars' `TemporalFunction`.""" - function, name = str(obj).split(".", maxsplit=1) + try: + function, name = str(obj).split(".", maxsplit=1) + except ValueError: + # Failed to unpack string + function = None if function != "TemporalFunction": raise ValueError("TemporalFunction required") return getattr(cls, name) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/string.py b/python/cudf_polars/cudf_polars/dsl/expressions/string.py index 03a6ee3b0e0..92c3c658c21 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/string.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/string.py @@ -83,7 +83,11 @@ class Name(IntEnum): @classmethod def from_polars(cls, obj: pl_expr.StringFunction) -> Self: """Convert from polars' `StringFunction`.""" - function, name = str(obj).split(".", maxsplit=1) + try: + function, name = str(obj).split(".", maxsplit=1) + except ValueError: + # Failed to unpack string + function = None if function != "StringFunction": raise ValueError("StringFunction required") return getattr(cls, name) From 5a04207cd657cb7782ab92dff620309e06d73199 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 25 Nov 2024 01:20:31 -0800 Subject: [PATCH 08/11] Add basic serialization tests --- .../tests/dsl/test_serialization.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 python/cudf_polars/tests/dsl/test_serialization.py diff --git a/python/cudf_polars/tests/dsl/test_serialization.py b/python/cudf_polars/tests/dsl/test_serialization.py new file mode 100644 index 00000000000..f2b9480880f --- /dev/null +++ b/python/cudf_polars/tests/dsl/test_serialization.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pickle + +import pytest + +from polars.polars import _expr_nodes as pl_expr + +from cudf_polars.dsl.expressions.boolean import BooleanFunction +from cudf_polars.dsl.expressions.datetime import TemporalFunction +from cudf_polars.dsl.expressions.string import StringFunction + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_function_name_serialization_all_values(function): + # Test serialization and deserialization for all values of function.Name + for name in function.Name: + serialized_name = pickle.dumps(name) + deserialized_name = pickle.loads(serialized_name) + assert deserialized_name is name + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_function_name_invalid(function): + # Test invalid attribute name + with pytest.raises( + AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" + ): + assert function.Name.InvalidAttribute is function.Name.InvalidAttribute + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_from_polars_all_names(function): + # Test that all valid names of polars expressions are correctly converted + for name in function.Name: + polars_function = getattr(pl_expr, function.__name__) + polars_function_attr = getattr(polars_function, name.name) + cudf_function = function.Name.from_polars(polars_function_attr) + assert cudf_function == name + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_from_polars_invalid_attribute(function): + # Test converting from invalid attribute name + with pytest.raises(ValueError, match=f"{function.__name__} required"): + function.Name.from_polars("InvalidAttribute") + + +@pytest.mark.parametrize( + "function", [BooleanFunction, TemporalFunction, StringFunction] +) +def test_from_polars_invalid_polars_attribute(function): + # Test converting from polars function with invalid attribute name + with pytest.raises( + AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" + ): + function.Name.from_polars(f"{function.__name__}.InvalidAttribute") From fcf820f0d39f046ce894bd1fc0876408aedb6177 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 25 Nov 2024 04:25:06 -0800 Subject: [PATCH 09/11] Replace pytest parametrization with a fixture --- .../tests/dsl/test_serialization.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/python/cudf_polars/tests/dsl/test_serialization.py b/python/cudf_polars/tests/dsl/test_serialization.py index f2b9480880f..3bfe180d21f 100644 --- a/python/cudf_polars/tests/dsl/test_serialization.py +++ b/python/cudf_polars/tests/dsl/test_serialization.py @@ -14,9 +14,11 @@ from cudf_polars.dsl.expressions.string import StringFunction -@pytest.mark.parametrize( - "function", [BooleanFunction, TemporalFunction, StringFunction] -) +@pytest.fixture(params=[BooleanFunction, StringFunction, TemporalFunction]) +def function(request): + return request.param + + def test_function_name_serialization_all_values(function): # Test serialization and deserialization for all values of function.Name for name in function.Name: @@ -25,9 +27,6 @@ def test_function_name_serialization_all_values(function): assert deserialized_name is name -@pytest.mark.parametrize( - "function", [BooleanFunction, TemporalFunction, StringFunction] -) def test_function_name_invalid(function): # Test invalid attribute name with pytest.raises( @@ -36,9 +35,6 @@ def test_function_name_invalid(function): assert function.Name.InvalidAttribute is function.Name.InvalidAttribute -@pytest.mark.parametrize( - "function", [BooleanFunction, TemporalFunction, StringFunction] -) def test_from_polars_all_names(function): # Test that all valid names of polars expressions are correctly converted for name in function.Name: @@ -48,18 +44,12 @@ def test_from_polars_all_names(function): assert cudf_function == name -@pytest.mark.parametrize( - "function", [BooleanFunction, TemporalFunction, StringFunction] -) def test_from_polars_invalid_attribute(function): # Test converting from invalid attribute name with pytest.raises(ValueError, match=f"{function.__name__} required"): function.Name.from_polars("InvalidAttribute") -@pytest.mark.parametrize( - "function", [BooleanFunction, TemporalFunction, StringFunction] -) def test_from_polars_invalid_polars_attribute(function): # Test converting from polars function with invalid attribute name with pytest.raises( From 4de3b79e92477d8595d28d07155203d0ea7f61db Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 25 Nov 2024 04:30:22 -0800 Subject: [PATCH 10/11] Rewrite test to check for all polars attributes --- python/cudf_polars/tests/dsl/test_serialization.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/cudf_polars/tests/dsl/test_serialization.py b/python/cudf_polars/tests/dsl/test_serialization.py index 3bfe180d21f..4a170d9fbbd 100644 --- a/python/cudf_polars/tests/dsl/test_serialization.py +++ b/python/cudf_polars/tests/dsl/test_serialization.py @@ -37,11 +37,13 @@ def test_function_name_invalid(function): def test_from_polars_all_names(function): # Test that all valid names of polars expressions are correctly converted + polars_function = getattr(pl_expr, function.__name__) + polars_names = [name for name in dir(polars_function) if not name.startswith("_")] + # Check names advertised by polars are the same as we advertise + assert set(polars_names) == set(function.Name.__members__) for name in function.Name: - polars_function = getattr(pl_expr, function.__name__) - polars_function_attr = getattr(polars_function, name.name) - cudf_function = function.Name.from_polars(polars_function_attr) - assert cudf_function == name + attr = getattr(polars_function, name.name) + assert function.Name.from_polars(attr) == name def test_from_polars_invalid_attribute(function): From 17a1a722664c113d1a0b1f6f19a0956bac0eb7e8 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 25 Nov 2024 10:01:12 -0800 Subject: [PATCH 11/11] Fix regex to be matched in tests --- python/cudf_polars/tests/dsl/test_serialization.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/cudf_polars/tests/dsl/test_serialization.py b/python/cudf_polars/tests/dsl/test_serialization.py index 4a170d9fbbd..7de8f959843 100644 --- a/python/cudf_polars/tests/dsl/test_serialization.py +++ b/python/cudf_polars/tests/dsl/test_serialization.py @@ -29,9 +29,7 @@ def test_function_name_serialization_all_values(function): def test_function_name_invalid(function): # Test invalid attribute name - with pytest.raises( - AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" - ): + with pytest.raises(AttributeError, match="InvalidAttribute"): assert function.Name.InvalidAttribute is function.Name.InvalidAttribute @@ -54,7 +52,5 @@ def test_from_polars_invalid_attribute(function): def test_from_polars_invalid_polars_attribute(function): # Test converting from polars function with invalid attribute name - with pytest.raises( - AttributeError, match="type object 'Name' has no attribute 'InvalidAttribute'" - ): + with pytest.raises(AttributeError, match="InvalidAttribute"): function.Name.from_polars(f"{function.__name__}.InvalidAttribute")