diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 38bfc857826..9808bdcab7a 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -243,19 +243,12 @@ class Select(IR): df: IR """Input dataframe.""" - cse: list[expr.NamedExpr] - """ - List of common subexpressions that will appear in the selected expressions. - - These must be evaluated before the returned expressions. - """ expr: list[expr.NamedExpr] """List of expressions to evaluate to form the new dataframe.""" def evaluate(self, *, cache: dict[int, DataFrame]): """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) - df = df.with_columns([e.evaluate(df) for e in self.cse]) return DataFrame([e.evaluate(df) for e in self.expr]) @@ -541,20 +534,13 @@ class HStack(IR): df: IR """Input dataframe.""" - cse: list[expr.NamedExpr] - """ - List of common subexpressions that will appear in the selected expressions. - - These must be evaluated before the returned expressions. - """ columns: list[expr.NamedExpr] """List of expressions to produce new columns.""" def evaluate(self, *, cache: dict[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" df = self.df.evaluate(cache=cache) - ctx = df.copy().with_columns([e.evaluate(df) for e in self.cse]) - return df.with_columns([c.evaluate(ctx) for c in self.columns]) + return df.with_columns([c.evaluate(df) for c in self.columns]) @dataclass(slots=True) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 17596e354fb..4837473ca5b 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -22,7 +22,21 @@ class set_node(AbstractContextManager): - """Run a block with current node set in the visitor.""" + """ + Run a block with current node set in the visitor. + + Parameters + ---------- + visitor + The internal Rust visitor object + n + The node to set as the current root. + + Notes + ----- + This is useful for translating expressions with a given node + active, restoring the node when the block exits. + """ __slots__ = ("n", "visitor") @@ -94,17 +108,16 @@ def _( def _(node: pl_ir.Select, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) - cse_exprs = [translate_named_expr(visitor, n=e) for e in node.cse_expr] - exprs = [translate_named_expr(visitor, n=e) for e in node.expr] - return ir.Select(schema, inp, cse_exprs, exprs) + exprs = [translate_named_expr(visitor, n=e) for e in node.expr] + return ir.Select(schema, inp, exprs) @_translate_ir.register def _(node: pl_ir.GroupBy, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) - aggs = [translate_named_expr(visitor, n=e) for e in node.aggs] - keys = [translate_named_expr(visitor, n=e) for e in node.keys] + aggs = [translate_named_expr(visitor, n=e) for e in node.aggs] + keys = [translate_named_expr(visitor, n=e) for e in node.keys] return ir.GroupBy( schema, inp, @@ -133,16 +146,15 @@ def _(node: pl_ir.Join, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: def _(node: pl_ir.HStack, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) - cse_exprs = [translate_named_expr(visitor, n=e) for e in node.cse_exprs] - exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] - return ir.HStack(schema, inp, cse_exprs, exprs) + exprs = [translate_named_expr(visitor, n=e) for e in node.exprs] + return ir.HStack(schema, inp, exprs) @_translate_ir.register def _(node: pl_ir.Reduce, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) - exprs = [translate_named_expr(visitor, n=e) for e in node.expr] + exprs = [translate_named_expr(visitor, n=e) for e in node.expr] return ir.Reduce(schema, inp, exprs) @@ -159,7 +171,7 @@ def _(node: pl_ir.Distinct, visitor: Any, schema: dict[str, plc.DataType]) -> ir def _(node: pl_ir.Sort, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) - by = [translate_named_expr(visitor, n=e) for e in node.by_column] + by = [translate_named_expr(visitor, n=e) for e in node.by_column] return ir.Sort(schema, inp, by, node.sort_options, node.slice) @@ -172,7 +184,7 @@ def _(node: pl_ir.Slice, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR def _(node: pl_ir.Filter, visitor: Any, schema: dict[str, plc.DataType]) -> ir.IR: with set_node(visitor, node.input): inp = translate_ir(visitor, n=None) - mask = translate_named_expr(visitor, n=node.predicate) + mask = translate_named_expr(visitor, n=node.predicate) return ir.Filter(schema, inp, mask) @@ -261,6 +273,13 @@ def translate_named_expr(visitor: Any, *, n: pl_expr.PyExprIR) -> expr.NamedExpr ------- Translated IR object. + Notes + ----- + The datatype of the internal expression will be obtained from the + visitor by calling ``get_dtype``, for this to work properly, the + caller should arrange that the expression is translated with the + node that it references "active" for the visitor (see :class:`set_node`). + Raises ------ NotImplementedError if any translation fails due to unsupported functionality.