Skip to content

Commit

Permalink
move back lower_ir_node default
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Dec 2, 2024
1 parent 503ad59 commit 043d268
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,33 @@
from cudf_polars.experimental.dispatch import LowerIRTransformer


@lower_ir_node.register(IR)
def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
# Default logic - Requires single partition

if len(ir.children) == 0:
# Default leaf node has single partition
return ir, {
ir: PartitionInfo(count=1)
} # pragma: no cover; Missed by pylibcudf executor

# Lower children
children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
partition_info = reduce(operator.or_, _partition_info)

# Check that child partitioning is supported
if any(partition_info[c].count > 1 for c in children):
raise NotImplementedError(
f"Class {type(ir)} does not support multiple partitions."
) # pragma: no cover

# Return reconstructed node and partition-info dict
partition = PartitionInfo(count=1)
new_node = ir.reconstruct(children)
partition_info[new_node] = partition
return new_node, partition_info


def lower_ir_graph(ir: IR) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
"""
Rewrite an IR graph and extract partitioning information.
Expand Down Expand Up @@ -107,32 +134,6 @@ def evaluate_dask(ir: IR) -> DataFrame:
return get(graph, key)


@lower_ir_node.register(IR)
def _(ir: IR, rec: LowerIRTransformer) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
# Single-partition fall-back for lower_ir_node
if len(ir.children) == 0:
# Default leaf node has single partition
return ir, {
ir: PartitionInfo(count=1)
} # pragma: no cover; Missed by pylibcudf executor

# Lower children
children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
partition_info = reduce(operator.or_, _partition_info)

# Check that child partitioning is supported
if any(partition_info[c].count > 1 for c in children):
raise NotImplementedError(
f"Class {type(ir)} does not support multiple partitions."
) # pragma: no cover

# Return reconstructed node and partition-info dict
partition = PartitionInfo(count=1)
new_node = ir.reconstruct(children)
partition_info[new_node] = partition
return new_node, partition_info


@generate_ir_tasks.register(IR)
def _(
ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
Expand Down

0 comments on commit 043d268

Please sign in to comment.