-
Notifications
You must be signed in to change notification settings - Fork 919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Abstract polars function expression nodes to ensure they are serializable #17418
Changes from 3 commits
6ea021c
9f2e5c8
d142f50
be7fa52
3f69d98
7d009f9
5675ef9
b7d2bf1
9b54437
5a04207
07ba6fe
fcf820f
4de3b79
17a1a72
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,12 @@ | |
|
||
from __future__ import annotations | ||
|
||
from enum import Enum, auto | ||
from functools import partial, reduce | ||
from typing import TYPE_CHECKING, Any, ClassVar | ||
|
||
import pyarrow as pa | ||
|
||
from polars.polars import _expr_nodes as pl_expr | ||
|
||
import pylibcudf as plc | ||
|
||
from cudf_polars.containers import Column | ||
|
@@ -31,22 +30,49 @@ | |
__all__ = ["BooleanFunction"] | ||
|
||
|
||
class BooleanFunctionName(Enum):
    """
    Internal names for the boolean functions handled by ``BooleanFunction``.

    Members mirror the names of the polars ``BooleanFunction`` variants so
    that expression nodes can store one of our own enum members instead of
    holding on to the polars object.
    """

    # NOTE: values come from ``auto()``, so they depend on declaration
    # order. Do not reorder existing members.
    All = auto()
    AllHorizontal = auto()
    Any = auto()
    AnyHorizontal = auto()
    IsBetween = auto()
    IsDuplicated = auto()
    IsFinite = auto()
    IsFirstDistinct = auto()
    IsIn = auto()
    IsInfinite = auto()
    IsLastDistinct = auto()
    IsNan = auto()
    IsNotNan = auto()
    IsNotNull = auto()
    IsNull = auto()
    IsUnique = auto()
    Not = auto()

    @classmethod
    def get_polars_type(
        cls, tp: "pl_expr.BooleanFunction"
    ) -> "BooleanFunctionName":
        """
        Convert a polars boolean function name to the matching member.

        Despite the method's name, the result is a member of this enum,
        not a polars object.

        Parameters
        ----------
        tp
            Object whose ``str()`` has the form
            ``"BooleanFunction.<MemberName>"``.

        Returns
        -------
        The member of this enum called ``<MemberName>``.

        Raises
        ------
        ValueError
            If ``str(tp)`` does not consist of exactly two dot-separated
            parts, or if the first part is not ``"BooleanFunction"``.
        AttributeError
            If this enum has no member called ``<MemberName>``.
        """
        function, name = str(tp).split(".")
        if function != "BooleanFunction":
            raise ValueError("BooleanFunction required")
        # Look the name up among the enum members only: a plain ``getattr``
        # would also hand back methods or dunder attributes of the class.
        try:
            return cls.__members__[name]
        except KeyError:
            # Keep the exception type a failed ``getattr`` would produce.
            raise AttributeError(
                f"{cls.__name__} has no member named {name!r}"
            ) from None
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this doesn't return a polars object, but one of ours. Also, because we're calling `getattr`, that suggests we should make this a classmethod.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I totally mixed everything up here, thanks for catching. Done in be7fa52. |
||
|
||
|
||
class BooleanFunction(Expr): | ||
__slots__ = ("name", "options") | ||
_non_child = ("dtype", "name", "options") | ||
|
||
def __init__( | ||
self, | ||
dtype: plc.DataType, | ||
name: pl_expr.BooleanFunction, | ||
name: BooleanFunctionName, | ||
options: tuple[Any, ...], | ||
*children: Expr, | ||
) -> None: | ||
self.dtype = dtype | ||
self.options = options | ||
self.name = name | ||
self.children = children | ||
if self.name == pl_expr.BooleanFunction.IsIn and not all( | ||
if self.name == BooleanFunctionName.IsIn and not all( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Because we have enums, we should now use `is` (identity comparison) rather than `==`. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done in 7d009f9. |
||
c.dtype == self.children[0].dtype for c in self.children | ||
): | ||
# TODO: If polars IR doesn't put the casts in, we need to | ||
|
@@ -110,12 +136,12 @@ def do_evaluate( | |
) -> Column: | ||
"""Evaluate this expression given a dataframe for context.""" | ||
if self.name in ( | ||
pl_expr.BooleanFunction.IsFinite, | ||
pl_expr.BooleanFunction.IsInfinite, | ||
BooleanFunctionName.IsFinite, | ||
BooleanFunctionName.IsInfinite, | ||
): | ||
# Avoid evaluating the child if the dtype tells us it's unnecessary. | ||
(child,) = self.children | ||
is_finite = self.name == pl_expr.BooleanFunction.IsFinite | ||
is_finite = self.name == BooleanFunctionName.IsFinite | ||
if child.dtype.id() not in (plc.TypeId.FLOAT32, plc.TypeId.FLOAT64): | ||
value = plc.interop.from_arrow( | ||
pa.scalar(value=is_finite, type=plc.interop.to_arrow(self.dtype)) | ||
|
@@ -142,10 +168,10 @@ def do_evaluate( | |
] | ||
# Kleene logic for Any (OR) and All (AND) if ignore_nulls is | ||
# False | ||
if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All): | ||
if self.name in (BooleanFunctionName.Any, BooleanFunctionName.All): | ||
(ignore_nulls,) = self.options | ||
(column,) = columns | ||
is_any = self.name == pl_expr.BooleanFunction.Any | ||
is_any = self.name == BooleanFunctionName.Any | ||
agg = plc.aggregation.any() if is_any else plc.aggregation.all() | ||
result = plc.reduce.reduce(column.obj, agg, self.dtype) | ||
if not ignore_nulls and column.obj.null_count() > 0: | ||
|
@@ -165,27 +191,27 @@ def do_evaluate( | |
# False || Null => Null True && Null => Null | ||
return Column(plc.Column.all_null_like(column.obj, 1)) | ||
return Column(plc.Column.from_scalar(result, 1)) | ||
if self.name == pl_expr.BooleanFunction.IsNull: | ||
if self.name == BooleanFunctionName.IsNull: | ||
(column,) = columns | ||
return Column(plc.unary.is_null(column.obj)) | ||
elif self.name == pl_expr.BooleanFunction.IsNotNull: | ||
elif self.name == BooleanFunctionName.IsNotNull: | ||
(column,) = columns | ||
return Column(plc.unary.is_valid(column.obj)) | ||
elif self.name == pl_expr.BooleanFunction.IsNan: | ||
elif self.name == BooleanFunctionName.IsNan: | ||
(column,) = columns | ||
return Column( | ||
plc.unary.is_nan(column.obj).with_mask( | ||
column.obj.null_mask(), column.obj.null_count() | ||
) | ||
) | ||
elif self.name == pl_expr.BooleanFunction.IsNotNan: | ||
elif self.name == BooleanFunctionName.IsNotNan: | ||
(column,) = columns | ||
return Column( | ||
plc.unary.is_not_nan(column.obj).with_mask( | ||
column.obj.null_mask(), column.obj.null_count() | ||
) | ||
) | ||
elif self.name == pl_expr.BooleanFunction.IsFirstDistinct: | ||
elif self.name == BooleanFunctionName.IsFirstDistinct: | ||
(column,) = columns | ||
return self._distinct( | ||
column, | ||
|
@@ -197,7 +223,7 @@ def do_evaluate( | |
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) | ||
), | ||
) | ||
elif self.name == pl_expr.BooleanFunction.IsLastDistinct: | ||
elif self.name == BooleanFunctionName.IsLastDistinct: | ||
(column,) = columns | ||
return self._distinct( | ||
column, | ||
|
@@ -209,7 +235,7 @@ def do_evaluate( | |
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) | ||
), | ||
) | ||
elif self.name == pl_expr.BooleanFunction.IsUnique: | ||
elif self.name == BooleanFunctionName.IsUnique: | ||
(column,) = columns | ||
return self._distinct( | ||
column, | ||
|
@@ -221,7 +247,7 @@ def do_evaluate( | |
pa.scalar(value=False, type=plc.interop.to_arrow(self.dtype)) | ||
), | ||
) | ||
elif self.name == pl_expr.BooleanFunction.IsDuplicated: | ||
elif self.name == BooleanFunctionName.IsDuplicated: | ||
(column,) = columns | ||
return self._distinct( | ||
column, | ||
|
@@ -233,7 +259,7 @@ def do_evaluate( | |
pa.scalar(value=True, type=plc.interop.to_arrow(self.dtype)) | ||
), | ||
) | ||
elif self.name == pl_expr.BooleanFunction.AllHorizontal: | ||
elif self.name == BooleanFunctionName.AllHorizontal: | ||
return Column( | ||
reduce( | ||
partial( | ||
|
@@ -244,7 +270,7 @@ def do_evaluate( | |
(c.obj for c in columns), | ||
) | ||
) | ||
elif self.name == pl_expr.BooleanFunction.AnyHorizontal: | ||
elif self.name == BooleanFunctionName.AnyHorizontal: | ||
return Column( | ||
reduce( | ||
partial( | ||
|
@@ -255,10 +281,10 @@ def do_evaluate( | |
(c.obj for c in columns), | ||
) | ||
) | ||
elif self.name == pl_expr.BooleanFunction.IsIn: | ||
elif self.name == BooleanFunctionName.IsIn: | ||
needles, haystack = columns | ||
return Column(plc.search.contains(haystack.obj, needles.obj)) | ||
elif self.name == pl_expr.BooleanFunction.Not: | ||
elif self.name == BooleanFunctionName.Not: | ||
(column,) = columns | ||
return Column( | ||
plc.unary.unary_operation(column.obj, plc.unary.UnaryOperator.NOT) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: use `IntEnum`.
Also, let's nest this class inside the `BooleanFunction` class below, and just call it `Name`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 3f69d98 .