From d86957c93d92201d590e58b137e5182382b254f7 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 22 Apr 2024 12:50:10 +0000 Subject: [PATCH] Better ergonomics for plan visitor too --- python/cudf_polars/cudf_polars/plan.py | 74 ++++++++++++-------------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/python/cudf_polars/cudf_polars/plan.py b/python/cudf_polars/cudf_polars/plan.py index 52fadd353a9..2a70390ad95 100644 --- a/python/cudf_polars/cudf_polars/plan.py +++ b/python/cudf_polars/cudf_polars/plan.py @@ -111,20 +111,25 @@ class PlanVisitor(NamedTuple): profiler: ExecutionProfiler | NoopProfiler expr_visitor: ExprVisitor - def node(self, n: int) -> Plan: + def __call__(self, n: int | None = None) -> DataFrame: """ - Translate node to python object. + Evaluate a plan node to produce a dataframe. Parameters ---------- n - Node to replace. + Node to evaluate (optional), if not provided uses the internal + visitor's "current" node. Returns ------- - Python representation of the node. + New dataframe giving the evaluation of the plan. """ - return self.visitor.view_node(n) + if n is None: + node = self.visitor.view_current_node() + else: + node = self.visitor.view_node(n) + return _execute_plan(node, self) def record(self, name: str): """ @@ -160,19 +165,15 @@ def execute_plan( ------- DataFrame representing the execution of the plan """ - plan = visitor.view_current_node() profiler: ExecutionProfiler | NoopProfiler if profile: profiler = ExecutionProfiler() - result = _execute_plan( - plan, PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor)) - ) - return result, profiler + return PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor))( + n=None + ), profiler else: profiler = NoopProfiler() - return _execute_plan( - plan, PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor)) - ) + return PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor))(n=None) @singledispatch @@ -273,9 +274,7 @@ def _cache(plan: nodes.Cache, visitor: PlanVisitor): try: return cache[key] except KeyError: - return cache.setdefault( - key, _execute_plan(visitor.node(plan.input), visitor) - ) + return cache.setdefault(key, visitor(plan.input)) @_execute_plan.register @@ -315,7 +314,7 @@ def _dataframescan(plan: nodes.DataFrameScan, visitor: PlanVisitor): @_execute_plan.register def _select(plan: nodes.Select, visitor: PlanVisitor): - context = _execute_plan(visitor.node(plan.input), visitor) + context = visitor(plan.input) with visitor.record("select"): # TODO: loses sortedness properties for cse in plan.cse_expr: @@ -334,7 +333,7 @@ def _select(plan: nodes.Select, visitor: PlanVisitor): def _groupby(plan: nodes.GroupBy, visitor: PlanVisitor): name = "group_by" if plan.options.rolling is None else "rolling" # Input frame to groupby - context = _execute_plan(visitor.node(plan.input), visitor) + context = visitor(plan.input) agg_names = [e.output_name for e in plan.aggs] agg_nodes = [e.node for e in plan.aggs] with visitor.record(name): @@ -414,8 +413,8 @@ def _groupby(plan: nodes.GroupBy, visitor: PlanVisitor): @_execute_plan.register def _join(plan: nodes.Join, visitor: PlanVisitor): - left = _execute_plan(visitor.node(plan.input_left), visitor) - right = _execute_plan(visitor.node(plan.input_right), visitor) + left = visitor(plan.input_left) + right = visitor(plan.input_right) with visitor.record("join"): left_on = plc.Table( [visitor.expr_visitor(e.node, left) for e in plan.left_on] @@ -509,7 +508,7 @@ def _join(plan: nodes.Join, visitor: PlanVisitor): @_execute_plan.register def _hstack(plan: nodes.HStack, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) with visitor.record("hstack"): columns = { e.output_name: visitor.expr_visitor(e.node, result) @@ -521,7 +520,7 @@ def _hstack(plan: nodes.HStack, visitor: PlanVisitor): @_execute_plan.register def _distinct(plan: nodes.Distinct, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) with visitor.record("distinct"): (keep, subset, maintain_order, zlice) = plan.options keep = { @@ -571,7 +570,7 @@ def _distinct(plan: nodes.Distinct, visitor: PlanVisitor): @_execute_plan.register def _sort(plan: nodes.Sort, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) with visitor.record("sort"): input_col_ids = set(map(id, result.values())) sort_keys = [ @@ -618,14 +617,14 @@ def _sort(plan: nodes.Sort, visitor: PlanVisitor): @_execute_plan.register def _slice(plan: nodes.Slice, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) with visitor.record("slice"): return result.slice(plan.offset, plan.len) @_execute_plan.register def _filter(plan: nodes.Filter, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) with visitor.record("filter"): mask = visitor.expr_visitor(plan.predicate.node, result) return result.filter(mask) @@ -633,7 +632,7 @@ def _filter(plan: nodes.Filter, visitor: PlanVisitor): @_execute_plan.register def _simple_projection(plan: nodes.SimpleProjection, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) schema = plan.columns with visitor.record("simple_projection"): return DataFrame({name: result[name] for name in schema}) @@ -647,7 +646,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor): (to_unnest,) = args raise NotImplementedError("unnest") elif typ == "drop_nulls": - context = _execute_plan(visitor.node(plan.input), visitor) + context = visitor(plan.input) with profiler: (subset,) = args subset = set(subset) @@ -663,7 +662,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor): ) elif typ == "rechunk": # No-op in a non-chunked setting - return _execute_plan(visitor.node(plan.input), visitor) + return visitor(plan.input) elif typ == "merge_sorted": pieces = plan.input # merge_sorted operates on Union inputs @@ -675,10 +674,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor): # We don't have that luxury so we assume we have a union, and # evaluate the pieces. assert isinstance(pieces, nodes.Union) - first, *rest = ( - _execute_plan(visitor.node(piece), visitor) - for piece in pieces.inputs - ) + first, *rest = (visitor(piece) for piece in pieces.inputs) with profiler: (key_column,) = args column_names = first.names() @@ -705,12 +701,12 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor): ), ) elif typ == "rename": - context = _execute_plan(visitor.node(plan.input), visitor) + context = visitor(plan.input) with profiler: old_names, new_names, _ = args return context.rename(dict(zip(old_names, new_names, strict=True))) elif typ == "explode": - context = _execute_plan(visitor.node(plan.input), visitor) + context = visitor(plan.input) with profiler: column_names, schema = args if len(column_names) > 1: @@ -734,9 +730,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor): @_execute_plan.register def _union(plan: nodes.Union, visitor: PlanVisitor): - input_tables = [ - _execute_plan(visitor.node(p), visitor) for p in plan.inputs - ] + input_tables = [visitor(p) for p in plan.inputs] with visitor.record("union"): # ordered set all_names = list( @@ -772,7 +766,7 @@ def _hconcat(plan: nodes.HConcat, visitor: PlanVisitor): return DataFrame( reduce( operator.or_, - (_execute_plan(visitor.node(p), visitor) for p in plan.inputs), + (visitor(p) for p in plan.inputs), {}, ) ) @@ -780,7 +774,7 @@ def _hconcat(plan: nodes.HConcat, visitor: PlanVisitor): @_execute_plan.register def _extcontext(plan: nodes.ExtContext, visitor: PlanVisitor): - result = _execute_plan(visitor.node(plan.input), visitor) + result = visitor(plan.input) # TODO: This is not right, e.g. if there is a projection that # selects some subset of the columns. But it seems it is not # pushed inside the ExtContext node, so we need some other way of @@ -788,7 +782,7 @@ def _extcontext(plan: nodes.ExtContext, visitor: PlanVisitor): return DataFrame( reduce( operator.or_, - (_execute_plan(visitor.node(p), visitor) for p in plan.contexts), + (visitor(p) for p in plan.contexts), result, ) )