Remove Scalar container type from polars interpreter #15953

Merged · 14 commits · Jun 11, 2024
8 changes: 7 additions & 1 deletion python/cudf_polars/cudf_polars/__init__.py
@@ -10,7 +10,13 @@

from __future__ import annotations

from cudf_polars._version import __git_commit__, __version__
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir

__all__: list[str] = ["execute_with_cudf", "translate_ir"]
__all__: list[str] = [
"execute_with_cudf",
"translate_ir",
"__git_commit__",
"__version__",
]
3 changes: 1 addition & 2 deletions python/cudf_polars/cudf_polars/containers/__init__.py
@@ -5,8 +5,7 @@

from __future__ import annotations

__all__: list[str] = ["DataFrame", "Column", "NamedColumn", "Scalar"]
__all__: list[str] = ["DataFrame", "Column", "NamedColumn"]

from cudf_polars.containers.column import Column, NamedColumn
from cudf_polars.containers.dataframe import DataFrame
from cudf_polars.containers.scalar import Scalar
28 changes: 27 additions & 1 deletion python/cudf_polars/cudf_polars/containers/column.py
@@ -17,12 +17,13 @@


class Column:
"""A column with sortedness metadata."""
"""An immutable column with sortedness metadata."""

obj: plc.Column
is_sorted: plc.types.Sorted
order: plc.types.Order
null_order: plc.types.NullOrder
is_scalar: bool

def __init__(
self,
@@ -33,10 +34,33 @@ def __init__(
null_order: plc.types.NullOrder = plc.types.NullOrder.BEFORE,
):
self.obj = column
self.is_scalar = self.obj.size() == 1
if self.obj.size() <= 1:
is_sorted = plc.types.Sorted.YES
self.is_sorted = is_sorted
self.order = order
self.null_order = null_order

@functools.cached_property
def obj_scalar(self) -> plc.Scalar:
"""
A copy of the column object as a pylibcudf Scalar.

Returns
-------
pylibcudf Scalar object.

Raises
------
ValueError
If the column is not length-1.
"""
if not self.is_scalar:
raise ValueError(
f"Cannot convert a column of length {self.obj.size()} to scalar"
)
return plc.copying.get_element(self.obj, 0)

def sorted_like(self, like: Column, /) -> Self:
"""
Copy sortedness properties from a column onto self.
@@ -81,6 +105,8 @@ def set_sorted(
-------
Self with metadata set.
"""
if self.obj.size() <= 1:
is_sorted = plc.types.Sorted.YES
self.is_sorted = is_sorted
self.order = order
self.null_order = null_order
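To illustrate the new Column semantics, here is a minimal sketch (not part of the diff; the literal values are invented): a length-1 column is flagged as is_scalar and marked sorted on construction, and obj_scalar returns a pylibcudf Scalar copy of its single element, while longer columns refuse the conversion.

import pyarrow as pa
import cudf._lib.pylibcudf as plc
from cudf_polars.containers import Column

# Length-1 column built from a pyarrow scalar (hypothetical value 42).
one = Column(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar(42)), 1))
assert one.is_scalar                          # length-1 columns are scalar-like
assert one.is_sorted == plc.types.Sorted.YES  # and are marked sorted on construction
scalar = one.obj_scalar                       # pylibcudf Scalar copy of the element

# A length-3 column is not scalar-like; accessing obj_scalar raises ValueError.
many = Column(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar(42)), 3))
assert not many.is_scalar
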
6 changes: 2 additions & 4 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
@@ -32,7 +32,7 @@ class DataFrame:
"""A representation of a dataframe."""

columns: list[NamedColumn]
table: plc.Table | None
table: plc.Table

def __init__(self, columns: Sequence[NamedColumn]) -> None:
self.columns = list(columns)
@@ -41,7 +41,7 @@ def __init__(self, columns: Sequence[NamedColumn]) -> None:

def copy(self) -> Self:
"""Return a shallow copy of self."""
return type(self)(self.columns)
return type(self)([c.copy() for c in self.columns])

def to_polars(self) -> pl.DataFrame:
"""Convert to a polars DataFrame."""
@@ -70,8 +70,6 @@ def num_columns(self) -> int:
@cached_property
def num_rows(self) -> int:
"""Number of rows."""
if self.table is None:
raise ValueError("Number of rows of frame with scalars makes no sense")
return self.table.num_rows()

@classmethod
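A short sketch of why copy now copies each column wrapper (an illustration under the assumption that NamedColumn.copy duplicates only the metadata wrapper, as the diff implies): sortedness flags can be set on the copy without mutating the original frame.

import pyarrow as pa
import cudf._lib.pylibcudf as plc
from cudf_polars.containers import DataFrame, NamedColumn

# Hypothetical three-row column named "a".
col = NamedColumn(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar(1)), 3), "a")
df = DataFrame([col])

df2 = df.copy()
df2.columns[0].set_sorted(
    is_sorted=plc.types.Sorted.YES,
    order=plc.types.Order.ASCENDING,
    null_order=plc.types.NullOrder.BEFORE,
)
# Only the copy's metadata changed; the original frame is untouched.
assert df.columns[0].is_sorted == plc.types.Sorted.NO
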
23 changes: 0 additions & 23 deletions python/cudf_polars/cudf_polars/containers/scalar.py

This file was deleted.

114 changes: 69 additions & 45 deletions python/cudf_polars/cudf_polars/dsl/expr.py
@@ -5,7 +5,7 @@
"""
DSL nodes for the polars expression language.

An expression node is a function, `DataFrame -> Column` or `DataFrame -> Scalar`.
An expression node is a function, `DataFrame -> Column`.

The evaluation context is provided by a LogicalPlan node, and can
affect the evaluation rule as well as providing the dataframe input.
@@ -26,7 +26,7 @@

import cudf._lib.pylibcudf as plc

from cudf_polars.containers import Column, NamedColumn, Scalar
from cudf_polars.containers import Column, NamedColumn
from cudf_polars.utils import sorting

if TYPE_CHECKING:
@@ -165,7 +165,7 @@ def do_evaluate(
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column: # TODO: return type is a lie for Literal
) -> Column:
"""
Evaluate this expression given a dataframe for context.

@@ -187,8 +187,7 @@

Returns
-------
Column representing the evaluation of the expression (or maybe
a scalar).
Column representing the evaluation of the expression.

Raises
------
@@ -205,7 +204,7 @@ def evaluate(
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column: # TODO: return type is a lie for Literal
) -> Column:
"""
Evaluate this expression given a dataframe for context.

Expand All @@ -222,23 +221,13 @@ def evaluate(

Notes
-----
Individual subclasses should implement :meth:`do_allocate`,
Individual subclasses should implement :meth:`do_evaluate`,
this method provides logic to handle lookups in the
substitution mapping.

The typed return value of :class:`Column` is not true when
evaluating :class:`Literal` nodes (which instead produce
:class:`Scalar` objects). However, these duck-type to having a
pylibcudf container object inside them, and usually they end
up appearing in binary expressions which pylibcudf handles
appropriately since there are overloads for (column, scalar)
pairs. We don't have to handle (scalar, scalar) in binops
since the polars optimizer has a constant-folding pass.

Returns
-------
Column representing the evaluation of the expression (or maybe
a scalar).
Column representing the evaluation of the expression.

Raises
------
@@ -319,24 +308,35 @@ def evaluate(
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> NamedColumn:
"""Evaluate this expression given a dataframe for context."""
"""
Evaluate this expression given a dataframe for context.

Parameters
----------
df
DataFrame providing context
context
Execution context
mapping
Substitution mapping

Returns
-------
NamedColumn attaching a name to an evaluated Column

See Also
--------
:meth:`Expr.evaluate` for details, this function just adds the
name to a column produced from an expression.
"""
obj = self.value.evaluate(df, context=context, mapping=mapping)
if isinstance(obj, Scalar):
return NamedColumn(
plc.Column.from_scalar(obj.obj, 1),
self.name,
is_sorted=plc.types.Sorted.YES,
order=plc.types.Order.ASCENDING,
null_order=plc.types.NullOrder.BEFORE,
)
else:
return NamedColumn(
obj.obj,
self.name,
is_sorted=obj.is_sorted,
order=obj.order,
null_order=obj.null_order,
)
return NamedColumn(
obj.obj,
self.name,
is_sorted=obj.is_sorted,
order=obj.order,
null_order=obj.null_order,
)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
Expand All @@ -363,7 +363,7 @@ def do_evaluate(
) -> Column:
"""Evaluate this expression given a dataframe for context."""
# datatype of pyarrow scalar is correct by construction.
return Scalar(plc.interop.from_arrow(self.value)) # type: ignore
return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1))


class Col(Expr):
@@ -402,8 +402,14 @@ def do_evaluate(
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
# TODO: type is wrong, and dtype
return df.num_rows # type: ignore
return Column(
plc.Column.from_scalar(
plc.interop.from_arrow(
pa.scalar(df.num_rows, type=plc.interop.to_arrow(self.dtype))
),
1,
)
)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
@@ -664,10 +670,24 @@ def do_evaluate(
return Column(plc.strings.case.to_upper(column.obj))
elif self.name == pl_expr.StringFunction.EndsWith:
column, suffix = columns
return Column(plc.strings.find.ends_with(column.obj, suffix.obj))
return Column(
plc.strings.find.ends_with(
column.obj,
suffix.obj_scalar
if column.obj.size() != suffix.obj.size() and suffix.is_scalar
else suffix.obj,
)
)
elif self.name == pl_expr.StringFunction.StartsWith:
column, suffix = columns
return Column(plc.strings.find.starts_with(column.obj, suffix.obj))
column, prefix = columns
return Column(
plc.strings.find.starts_with(
column.obj,
prefix.obj_scalar
if column.obj.size() != prefix.obj.size() and prefix.is_scalar
else prefix.obj,
)
)
else:
raise NotImplementedError(f"StringFunction {self.name}")
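
A minimal sketch of the broadcasting rule above (the example strings are invented): when the suffix is a length-1 column but the target is longer, it is passed as a pylibcudf Scalar so the (column, scalar) overload of ends_with applies; equal-length inputs keep the column-wise overload.

import pyarrow as pa
import cudf._lib.pylibcudf as plc
from cudf_polars.containers import Column

# Three rows of "abc" and a single-row suffix "c" (hypothetical data).
target = Column(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar("abc")), 3))
suffix = Column(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar("c")), 1))

needle = (
    suffix.obj_scalar
    if target.obj.size() != suffix.obj.size() and suffix.is_scalar
    else suffix.obj
)
result = Column(plc.strings.find.ends_with(target.obj, needle))  # boolean column, length 3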

@@ -875,9 +895,6 @@ def __init__(
self, dtype: plc.DataType, name: str, options: Any, value: Expr
) -> None:
super().__init__(dtype)
# TODO: fix polars name
if name == "nunique":
name = "n_unique"
self.name = name
self.options = options
self.children = (value,)
@@ -1092,8 +1109,15 @@ def do_evaluate(
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
lop = left.obj
rop = right.obj
if left.obj.size() != right.obj.size():
if left.is_scalar:
lop = left.obj_scalar
elif right.is_scalar:
rop = right.obj_scalar
return Column(
plc.binaryop.binary_operation(left.obj, right.obj, self.op, self.dtype),
plc.binaryop.binary_operation(lop, rop, self.op, self.dtype),
)

def collect_agg(self, *, depth: int) -> AggInfo:
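
Similarly for binary expressions, a minimal sketch of the length-1 broadcast (operands and output dtype are invented for the example): a scalar-like operand is converted with obj_scalar so binary_operation can broadcast it against the longer column.

import pyarrow as pa
import cudf._lib.pylibcudf as plc
from cudf_polars.containers import Column

left = Column(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar(2)), 4))    # [2, 2, 2, 2]
right = Column(plc.Column.from_scalar(plc.interop.from_arrow(pa.scalar(10)), 1))  # [10]

lop, rop = left.obj, right.obj
if left.obj.size() != right.obj.size():
    if left.is_scalar:
        lop = left.obj_scalar
    elif right.is_scalar:
        rop = right.obj_scalar

# Broadcasts the length-1 right operand: the result is a length-4 column of 12s.
out = Column(
    plc.binaryop.binary_operation(
        lop, rop, plc.binaryop.BinaryOperator.ADD, plc.DataType(plc.TypeId.INT64)
    )
)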