Skip to content

Commit

Permalink
Implement handlers for boolean functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Apr 24, 2024
1 parent d44f0ed commit b6d1210
Showing 1 changed file with 169 additions and 3 deletions.
172 changes: 169 additions & 3 deletions python/cudf_polars/cudf_polars/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,173 @@ def evaluate_expr(
raise AssertionError(f"Unhandled expression type {type(expr)}")


# Names of the function expressions that are routed to ``boolean_function``.
# Membership here only controls dispatch; a listed name may still raise
# NotImplementedError if no handler has been written for it yet.
BOOLEAN_FUNCTIONS = {
    # Null checks
    "is_null", "is_not_null",
    # Reductions
    "any", "all",
    # Floating point classification
    "is_finite", "is_infinite", "is_nan", "is_not_nan",
    # Distinctness masks
    "is_first_distinct", "is_last_distinct", "is_unique", "is_duplicated",
    # Membership / range tests
    "is_between", "is_in",
    # Row-wise reductions across columns
    "all_horizontal", "any_horizontal",
    # Negation
    "not",
}


def _distinct_mask(arguments: list[ColumnType], keep, *, marked: bool) -> ColumnType:
    """
    Build a boolean mask from the distinct row indices of some columns.

    Parameters
    ----------
    arguments
        Columns that together form the rows to compare.
    keep
        A ``plc.stream_compaction.DuplicateKeepOption`` selecting which
        rows ``distinct_indices`` reports.
    marked
        Value written at the reported indices; every other row gets
        ``not marked``.

    Returns
    -------
    Boolean column with one entry per input row.
    """
    table = plc.Table(arguments)
    indices = plc.stream_compaction.distinct_indices(
        table,
        keep,
        plc.types.NullEquality.EQUAL,
        plc.types.NanEquality.ALL_EQUAL,
    )
    # Start from a column filled with the "unmarked" value and scatter the
    # "marked" value into the positions reported by distinct_indices.
    background = plc.Column.from_scalar(
        plc.interop.from_arrow(pa.scalar(not marked)),
        table.num_rows(),
    )
    result = plc.copying.scatter(
        [plc.interop.from_arrow(pa.scalar(marked))],
        indices,
        plc.Table([background]),
    )
    (mask,) = result.columns()
    return mask


def boolean_function(
    name: str, arguments: list[ColumnType], options
) -> ColumnType:
    """
    Apply a function returning a boolean column to some arguments.

    Parameters
    ----------
    name
        Name of the function to apply
    arguments
        List of columns to apply to
    options
        Any options for the function

    Returns
    -------
    New column.

    Raises
    ------
    NotImplementedError
        If the function has no handler here (including ``is_finite`` and
        ``is_infinite``), or if ``any``/``all`` is requested with
        ``ignore_nulls`` false (Kleene logic is not implemented).
    """
    keep_option = plc.stream_compaction.DuplicateKeepOption

    if name == "is_null":
        return plc.unary.is_null(*arguments)
    elif name == "is_not_null":
        return plc.unary.is_valid(*arguments)
    elif name in ("any", "all"):
        (ignore_nulls,) = options
        if not ignore_nulls:
            raise NotImplementedError(f"Kleene logic for {name}")
        (column,) = arguments
        aggregation = (
            plc.aggregation.any() if name == "any" else plc.aggregation.all()
        )
        # The reduction yields a scalar; wrap it as a length-1 column.
        return plc.Column.from_scalar(
            plc.reduce.reduce(column, aggregation, column.type()), 1
        )
    elif name == "is_finite":
        raise NotImplementedError("is_finite")
    elif name == "is_infinite":
        raise NotImplementedError("is_infinite")
    elif name == "is_nan":
        # TODO: polars is_nan returns NULL for null inputs
        # so need to carry over null mask, or add argument
        return plc.unary.is_nan(*arguments)
    elif name == "is_not_nan":
        # TODO: https://github.com/pola-rs/polars/issues/15862
        return plc.unary.is_not_nan(*arguments)
    elif name == "is_first_distinct":
        return _distinct_mask(arguments, keep_option.KEEP_FIRST, marked=True)
    elif name == "is_last_distinct":
        return _distinct_mask(arguments, keep_option.KEEP_LAST, marked=True)
    elif name == "is_unique":
        # KEEP_NONE reports only rows without duplicates: mark those True.
        return _distinct_mask(arguments, keep_option.KEEP_NONE, marked=True)
    elif name == "is_duplicated":
        # Complement of is_unique: rows without duplicates are marked False.
        return _distinct_mask(arguments, keep_option.KEEP_NONE, marked=False)
    elif name == "not":
        # Fixed: the enum is UnaryOperator (was misspelled "UnaryOpterator",
        # which raised AttributeError whenever this branch was reached).
        return plc.unary.unary_operation(
            *arguments, plc.unary.UnaryOperator.NOT
        )
    else:
        raise NotImplementedError(f"unary boolean function {name}")


@evaluate_expr.register
def _expr_function(
expr: expr_nodes.Function, context: DataFrame, visitor: ExprVisitor
Expand Down Expand Up @@ -148,9 +315,8 @@ def _expr_function(
# return data.set_sorted(
# {name: getattr(DataFrame.IsSorted, flag.upper())}
# )
elif fname == "is_not_null":
(column,) = arguments
return plc.unary.is_valid(column)
elif fname in BOOLEAN_FUNCTIONS:
return boolean_function(fname, arguments, fargs)
else:
raise NotImplementedError(f"Function expression {fname=}")

Expand Down

0 comments on commit b6d1210

Please sign in to comment.