Skip to content

Commit

Permalink
Better ergonomics for plan visitor too
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Apr 22, 2024
1 parent a32119c commit d86957c
Showing 1 changed file with 34 additions and 40 deletions.
74 changes: 34 additions & 40 deletions python/cudf_polars/cudf_polars/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,25 @@ class PlanVisitor(NamedTuple):
profiler: ExecutionProfiler | NoopProfiler
expr_visitor: ExprVisitor

def node(self, n: int) -> Plan:
def __call__(self, n: int | None = None) -> DataFrame:
"""
Translate node to python object.
Evaluate a plan node to produce a dataframe.
Parameters
----------
n
Node to replace.
Node to evaluate (optional). If not provided, the internal
visitor's "current" node is used.
Returns
-------
Python representation of the node.
New dataframe giving the evaluation of the plan.
"""
return self.visitor.view_node(n)
if n is None:
node = self.visitor.view_current_node()
else:
node = self.visitor.view_node(n)
return _execute_plan(node, self)

def record(self, name: str):
"""
Expand Down Expand Up @@ -160,19 +165,15 @@ def execute_plan(
-------
DataFrame representing the execution of the plan
"""
plan = visitor.view_current_node()
profiler: ExecutionProfiler | NoopProfiler
if profile:
profiler = ExecutionProfiler()
result = _execute_plan(
plan, PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor))
)
return result, profiler
return PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor))(
n=None
), profiler
else:
profiler = NoopProfiler()
return _execute_plan(
plan, PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor))
)
return PlanVisitor(visitor, {}, profiler, ExprVisitor(visitor))(n=None)


@singledispatch
Expand Down Expand Up @@ -273,9 +274,7 @@ def _cache(plan: nodes.Cache, visitor: PlanVisitor):
try:
return cache[key]
except KeyError:
return cache.setdefault(
key, _execute_plan(visitor.node(plan.input), visitor)
)
return cache.setdefault(key, visitor(plan.input))


@_execute_plan.register
Expand Down Expand Up @@ -315,7 +314,7 @@ def _dataframescan(plan: nodes.DataFrameScan, visitor: PlanVisitor):

@_execute_plan.register
def _select(plan: nodes.Select, visitor: PlanVisitor):
context = _execute_plan(visitor.node(plan.input), visitor)
context = visitor(plan.input)
with visitor.record("select"):
# TODO: loses sortedness properties
for cse in plan.cse_expr:
Expand All @@ -334,7 +333,7 @@ def _select(plan: nodes.Select, visitor: PlanVisitor):
def _groupby(plan: nodes.GroupBy, visitor: PlanVisitor):
name = "group_by" if plan.options.rolling is None else "rolling"
# Input frame to groupby
context = _execute_plan(visitor.node(plan.input), visitor)
context = visitor(plan.input)
agg_names = [e.output_name for e in plan.aggs]
agg_nodes = [e.node for e in plan.aggs]
with visitor.record(name):
Expand Down Expand Up @@ -414,8 +413,8 @@ def _groupby(plan: nodes.GroupBy, visitor: PlanVisitor):

@_execute_plan.register
def _join(plan: nodes.Join, visitor: PlanVisitor):
left = _execute_plan(visitor.node(plan.input_left), visitor)
right = _execute_plan(visitor.node(plan.input_right), visitor)
left = visitor(plan.input_left)
right = visitor(plan.input_right)
with visitor.record("join"):
left_on = plc.Table(
[visitor.expr_visitor(e.node, left) for e in plan.left_on]
Expand Down Expand Up @@ -509,7 +508,7 @@ def _join(plan: nodes.Join, visitor: PlanVisitor):

@_execute_plan.register
def _hstack(plan: nodes.HStack, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
with visitor.record("hstack"):
columns = {
e.output_name: visitor.expr_visitor(e.node, result)
Expand All @@ -521,7 +520,7 @@ def _hstack(plan: nodes.HStack, visitor: PlanVisitor):

@_execute_plan.register
def _distinct(plan: nodes.Distinct, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
with visitor.record("distinct"):
(keep, subset, maintain_order, zlice) = plan.options
keep = {
Expand Down Expand Up @@ -571,7 +570,7 @@ def _distinct(plan: nodes.Distinct, visitor: PlanVisitor):

@_execute_plan.register
def _sort(plan: nodes.Sort, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
with visitor.record("sort"):
input_col_ids = set(map(id, result.values()))
sort_keys = [
Expand Down Expand Up @@ -618,22 +617,22 @@ def _sort(plan: nodes.Sort, visitor: PlanVisitor):

@_execute_plan.register
def _slice(plan: nodes.Slice, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
with visitor.record("slice"):
return result.slice(plan.offset, plan.len)


@_execute_plan.register
def _filter(plan: nodes.Filter, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
with visitor.record("filter"):
mask = visitor.expr_visitor(plan.predicate.node, result)
return result.filter(mask)


@_execute_plan.register
def _simple_projection(plan: nodes.SimpleProjection, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
schema = plan.columns
with visitor.record("simple_projection"):
return DataFrame({name: result[name] for name in schema})
Expand All @@ -647,7 +646,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor):
(to_unnest,) = args
raise NotImplementedError("unnest")
elif typ == "drop_nulls":
context = _execute_plan(visitor.node(plan.input), visitor)
context = visitor(plan.input)
with profiler:
(subset,) = args
subset = set(subset)
Expand All @@ -663,7 +662,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor):
)
elif typ == "rechunk":
# No-op in a non-chunked setting
return _execute_plan(visitor.node(plan.input), visitor)
return visitor(plan.input)
elif typ == "merge_sorted":
pieces = plan.input
# merge_sorted operates on Union inputs
Expand All @@ -675,10 +674,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor):
# We don't have that luxury so we assume we have a union, and
# evaluate the pieces.
assert isinstance(pieces, nodes.Union)
first, *rest = (
_execute_plan(visitor.node(piece), visitor)
for piece in pieces.inputs
)
first, *rest = (visitor(piece) for piece in pieces.inputs)
with profiler:
(key_column,) = args
column_names = first.names()
Expand All @@ -705,12 +701,12 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor):
),
)
elif typ == "rename":
context = _execute_plan(visitor.node(plan.input), visitor)
context = visitor(plan.input)
with profiler:
old_names, new_names, _ = args
return context.rename(dict(zip(old_names, new_names, strict=True)))
elif typ == "explode":
context = _execute_plan(visitor.node(plan.input), visitor)
context = visitor(plan.input)
with profiler:
column_names, schema = args
if len(column_names) > 1:
Expand All @@ -734,9 +730,7 @@ def _map_function(plan: nodes.MapFunction, visitor: PlanVisitor):

@_execute_plan.register
def _union(plan: nodes.Union, visitor: PlanVisitor):
input_tables = [
_execute_plan(visitor.node(p), visitor) for p in plan.inputs
]
input_tables = [visitor(p) for p in plan.inputs]
with visitor.record("union"):
# ordered set
all_names = list(
Expand Down Expand Up @@ -772,23 +766,23 @@ def _hconcat(plan: nodes.HConcat, visitor: PlanVisitor):
return DataFrame(
reduce(
operator.or_,
(_execute_plan(visitor.node(p), visitor) for p in plan.inputs),
(visitor(p) for p in plan.inputs),
{},
)
)


@_execute_plan.register
def _extcontext(plan: nodes.ExtContext, visitor: PlanVisitor):
result = _execute_plan(visitor.node(plan.input), visitor)
result = visitor(plan.input)
# TODO: This is not right, e.g. if there is a projection that
# selects some subset of the columns. But it seems it is not
# pushed inside the ExtContext node, so we need some other way of
# handling that.
return DataFrame(
reduce(
operator.or_,
(_execute_plan(visitor.node(p), visitor) for p in plan.contexts),
(visitor(p) for p in plan.contexts),
result,
)
)
Expand Down

0 comments on commit d86957c

Please sign in to comment.