Skip to content

Commit

Permalink
No more CSE exprs
Browse files Browse the repository at this point in the history
Expressions must now be translated with the node which is to provide
the schema active.
  • Loading branch information
wence- committed Jun 4, 2024
1 parent dc43d11 commit 0521592
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
16 changes: 1 addition & 15 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,12 @@ class Select(IR):

df: IR
"""Input dataframe."""
cse: list[expr.NamedExpr]
"""
List of common subexpressions that will appear in the selected expressions.
These must be evaluated before the returned expressions.
"""
expr: list[expr.NamedExpr]
"""List of expressions to evaluate to form the new dataframe."""

def evaluate(self, *, cache: dict[int, DataFrame]):
"""Evaluate and return a dataframe."""
df = self.df.evaluate(cache=cache)
df = df.with_columns([e.evaluate(df) for e in self.cse])
return DataFrame([e.evaluate(df) for e in self.expr])


Expand Down Expand Up @@ -541,20 +534,13 @@ class HStack(IR):

df: IR
"""Input dataframe."""
cse: list[expr.NamedExpr]
"""
List of common subexpressions that will appear in the selected expressions.
These must be evaluated before the returned expressions.
"""
columns: list[expr.NamedExpr]
"""List of expressions to produce new columns."""

def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
df = self.df.evaluate(cache=cache)
ctx = df.copy().with_columns([e.evaluate(df) for e in self.cse])
return df.with_columns([c.evaluate(ctx) for c in self.columns])
return df.with_columns([c.evaluate(df) for c in self.columns])


@dataclass(slots=True)
Expand Down
43 changes: 31 additions & 12 deletions python/cudf_polars/cudf_polars/dsl/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,21 @@


class set_node(AbstractContextManager):
"""Run a block with current node set in the visitor."""
"""
Run a block with current node set in the visitor.
Parameters
----------
visitor
The internal Rust visitor object
n
The node to set as the current root.
Notes
-----
This is useful for translating expressions with a given node
active, restoring the node when the block exits.
"""

__slots__ = ("n", "visitor")

Expand Down Expand Up @@ -94,17 +108,16 @@ def _(
def _(node: pl_ir.Select, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
cse_exprs = [translate_named_expr(visitor, n=e) for e in node.cse_expr]
exprs = [translate_named_expr(visitor, n=e) for e in node.expr]
return ir.Select(schema, inp, cse_exprs, exprs)
exprs = [translate_named_expr(visitor, n=e) for e in node.expr]
return ir.Select(schema, inp, exprs)


@_translate_ir.register
def _(node: pl_ir.GroupBy, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
aggs = [translate_named_expr(visitor, n=e) for e in node.aggs]
keys = [translate_named_expr(visitor, n=e) for e in node.keys]
aggs = [translate_named_expr(visitor, n=e) for e in node.aggs]
keys = [translate_named_expr(visitor, n=e) for e in node.keys]
return ir.GroupBy(
schema,
inp,
Expand Down Expand Up @@ -133,16 +146,15 @@ def _(node: pl_ir.Join, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
def _(node: pl_ir.HStack, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
cse_exprs = [translate_named_expr(visitor, n=e) for e in node.cse_exprs]
exprs = [translate_named_expr(visitor, n=e) for e in node.exprs]
return ir.HStack(schema, inp, cse_exprs, exprs)
exprs = [translate_named_expr(visitor, n=e) for e in node.exprs]
return ir.HStack(schema, inp, exprs)


@_translate_ir.register
def _(node: pl_ir.Reduce, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
exprs = [translate_named_expr(visitor, n=e) for e in node.expr]
exprs = [translate_named_expr(visitor, n=e) for e in node.expr]
return ir.Reduce(schema, inp, exprs)


Expand All @@ -159,7 +171,7 @@ def _(node: pl_ir.Distinct, visitor: Any, schema: dict[str, plc.DataType]) -> ir
def _(node: pl_ir.Sort, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
by = [translate_named_expr(visitor, n=e) for e in node.by_column]
by = [translate_named_expr(visitor, n=e) for e in node.by_column]
return ir.Sort(schema, inp, by, node.sort_options, node.slice)


Expand All @@ -172,7 +184,7 @@ def _(node: pl_ir.Slice, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR
def _(node: pl_ir.Filter, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR:
with set_node(visitor, node.input):
inp = translate_ir(visitor, n=None)
mask = translate_named_expr(visitor, n=node.predicate)
mask = translate_named_expr(visitor, n=node.predicate)
return ir.Filter(schema, inp, mask)


Expand Down Expand Up @@ -261,6 +273,13 @@ def translate_named_expr(visitor: Any, *, n: pl_expr.PyExprIR) -> expr.NamedExpr
-------
Translated IR object.
Notes
-----
The datatype of the internal expression will be obtained from the
visitor by calling ``get_dtype``, for this to work properly, the
caller should arrange that the expression is translated with the
node that it references "active" for the visitor (see :class:`set_node`).
Raises
------
NotImplementedError if any translation fails due to unsupported functionality.
Expand Down

0 comments on commit 0521592

Please sign in to comment.