Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more complete type annotations in polars interpreter #15942

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ repos:
- id: rapids-dependency-file-generator
args: ["--clean"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3
rev: v0.4.8
hooks:
- id: ruff
files: python/.*$
Expand Down
5 changes: 4 additions & 1 deletion python/cudf_polars/cudf_polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@

from __future__ import annotations

__all__: list[str] = []
from cudf_polars.callback import execute_with_cudf
from cudf_polars.dsl.translate import translate_ir

__all__: list[str] = ["execute_with_cudf", "translate_ir"]
3 changes: 2 additions & 1 deletion python/cudf_polars/cudf_polars/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import polars as pl

from cudf_polars.dsl.ir import IR
from cudf_polars.typing import NodeTraverser

__all__: list[str] = ["execute_with_cudf"]

Expand All @@ -33,7 +34,7 @@ def _callback(
return ir.evaluate(cache={}).to_polars()


def execute_with_cudf(nt, *, raise_on_fail: bool = False) -> None:
def execute_with_cudf(nt: NodeTraverser, *, raise_on_fail: bool = False) -> None:
"""
A post optimization callback that attempts to execute the plan with cudf.

Expand Down
13 changes: 7 additions & 6 deletions python/cudf_polars/cudf_polars/containers/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import polars as pl

Expand All @@ -17,6 +17,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence, Set

import pyarrow as pa
from typing_extensions import Self

import cudf
Expand Down Expand Up @@ -44,13 +45,13 @@ def copy(self) -> Self:

def to_polars(self) -> pl.DataFrame:
"""Convert to a polars DataFrame."""
return pl.from_arrow(
plc.interop.to_arrow(
self.table,
[plc.interop.ColumnMetadata(name=c.name) for c in self.columns],
)
table: pa.Table = plc.interop.to_arrow(
self.table,
[plc.interop.ColumnMetadata(name=c.name) for c in self.columns],
)

return cast(pl.DataFrame, pl.from_arrow(table))

@cached_property
def column_names_set(self) -> frozenset[str]:
"""Return the column names as a set."""
Expand Down
55 changes: 39 additions & 16 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,14 @@ def is_equal(self, other: Any) -> bool:
other.children
)

def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Equality of expressions."""
if type(self) != type(other) or hash(self) != hash(other):
return False
else:
return self.is_equal(other)

def __ne__(self, other) -> bool:
def __ne__(self, other: Any) -> bool:
"""Inequality of expressions."""
return not self.__eq__(other)

Expand Down Expand Up @@ -285,6 +285,8 @@ class NamedExpr:
# when evaluating expressions themselves, only when constructing
# named return values in dataframe (IR) nodes.
__slots__ = ("name", "value")
value: Expr
name: str

def __init__(self, name: str, value: Expr) -> None:
self.name = name
Expand All @@ -298,15 +300,15 @@ def __repr__(self) -> str:
"""Repr of the expression."""
return f"NamedExpr({self.name}, {self.value}"

def __eq__(self, other) -> bool:
def __eq__(self, other: Any) -> bool:
"""Equality of two expressions."""
return (
type(self) is type(other)
and self.name == other.name
and self.value == other.value
)

def __ne__(self, other) -> bool:
def __ne__(self, other: Any) -> bool:
"""Inequality of expressions."""
return not self.__eq__(other)

Expand Down Expand Up @@ -344,9 +346,10 @@ def collect_agg(self, *, depth: int) -> AggInfo:
class Literal(Expr):
__slots__ = ("value",)
_non_child = ("dtype", "value")
value: pa.Scalar
value: pa.Scalar[Any]
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: pa.Scalar) -> None:
def __init__(self, dtype: plc.DataType, value: pa.Scalar[Any]) -> None:
super().__init__(dtype)
assert value.type == plc.interop.to_arrow(dtype)
self.value = value
Expand All @@ -367,6 +370,7 @@ class Col(Expr):
__slots__ = ("name",)
_non_child = ("dtype", "name")
name: str
children: tuple[()]
mroeschke marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, dtype: plc.DataType, name: str) -> None:
self.dtype = dtype
Expand All @@ -388,6 +392,8 @@ def collect_agg(self, *, depth: int) -> AggInfo:


class Len(Expr):
children: tuple[()]

def do_evaluate(
self,
df: DataFrame,
Expand All @@ -410,8 +416,15 @@ def collect_agg(self, *, depth: int) -> AggInfo:
class BooleanFunction(Expr):
__slots__ = ("name", "options", "children")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(self, dtype: plc.DataType, name: str, options: tuple, *children: Expr):
def __init__(
self,
dtype: plc.DataType,
name: pl_expr.BooleanFunction,
options: tuple[Any, ...],
*children: Expr,
) -> None:
super().__init__(dtype)
self.options = options
self.name = name
Expand Down Expand Up @@ -610,14 +623,15 @@ def do_evaluate(
class StringFunction(Expr):
__slots__ = ("name", "options", "children")
_non_child = ("dtype", "name", "options")
children: tuple[Expr, ...]

def __init__(
self,
dtype: plc.DataType,
name: pl_expr.StringFunction,
options: tuple,
options: tuple[Any, ...],
*children: Expr,
):
) -> None:
super().__init__(dtype)
self.options = options
self.name = name
Expand Down Expand Up @@ -661,10 +675,11 @@ def do_evaluate(
class Sort(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr]

def __init__(
self, dtype: plc.DataType, options: tuple[bool, bool, bool], column: Expr
):
) -> None:
super().__init__(dtype)
self.options = options
self.children = (column,)
Expand Down Expand Up @@ -696,14 +711,15 @@ def do_evaluate(
class SortBy(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr, ...]

def __init__(
self,
dtype: plc.DataType,
options: tuple[bool, tuple[bool], tuple[bool]],
column: Expr,
*by: Expr,
):
) -> None:
super().__init__(dtype)
self.options = options
self.children = (column, *by)
Expand Down Expand Up @@ -734,8 +750,9 @@ def do_evaluate(
class Gather(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr) -> None:
super().__init__(dtype)
self.children = (values, indices)

Expand Down Expand Up @@ -775,6 +792,7 @@ def do_evaluate(
class Filter(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr, Expr]

def __init__(self, dtype: plc.DataType, values: Expr, indices: Expr):
super().__init__(dtype)
Expand All @@ -801,8 +819,9 @@ def do_evaluate(
class RollingWindow(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr):
def __init__(self, dtype: plc.DataType, options: Any, agg: Expr) -> None:
super().__init__(dtype)
self.options = options
self.children = (agg,)
Expand All @@ -811,8 +830,9 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr):
class GroupedRollingWindow(Expr):
__slots__ = ("options", "children")
_non_child = ("dtype", "options")
children: tuple[Expr, ...]

def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr):
def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr) -> None:
super().__init__(dtype)
self.options = options
self.children = (agg, *by)
Expand All @@ -821,8 +841,9 @@ def __init__(self, dtype: plc.DataType, options: Any, agg: Expr, *by: Expr):
class Cast(Expr):
__slots__ = ("children",)
_non_child = ("dtype",)
children: tuple[Expr]

def __init__(self, dtype: plc.DataType, value: Expr):
def __init__(self, dtype: plc.DataType, value: Expr) -> None:
super().__init__(dtype)
self.children = (value,)

Expand All @@ -848,6 +869,7 @@ def collect_agg(self, *, depth: int) -> AggInfo:
class Agg(Expr):
__slots__ = ("name", "options", "op", "request", "children")
_non_child = ("dtype", "name", "options")
children: tuple[Expr]

def __init__(
self, dtype: plc.DataType, name: str, options: Any, value: Expr
Expand Down Expand Up @@ -1007,7 +1029,7 @@ def _last(self, column: Column) -> Column:

def do_evaluate(
self,
df,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
Expand All @@ -1022,6 +1044,7 @@ def do_evaluate(
class BinOp(Expr):
__slots__ = ("op", "children")
_non_child = ("dtype", "op")
children: tuple[Expr, Expr]

def __init__(
self,
Expand Down
Loading
Loading