Skip to content

Commit

Permalink
feat: add support for ibis 5.x
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed May 17, 2023
1 parent ec8160a commit 22c56b6
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 53 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ jobs:
ibis-version:
- "3.2"
- "4.1"
- "5.1"
include:
- os: windows-latest
python-version: "3.10"
Expand Down
14 changes: 11 additions & 3 deletions ibis_substrait/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def extension_lookup(
if function_extension is None:
sigkey = tuple(
[
IBIS_SUBSTRAIT_TYPE_MAPPING[arg.op().output_dtype.name]
IBIS_SUBSTRAIT_TYPE_MAPPING[arg.op().output_dtype.name] # type: ignore
for arg in op.args
if arg is not None and isinstance(arg, (ir.Expr, ops.Node))
]
Expand Down Expand Up @@ -183,7 +183,7 @@ def register_extension_uri(self, uri: str) -> ste.SimpleExtensionURI:

return extension_uri

def compile(self, expr: ir.TableExpr, **kwargs: Any) -> stp.Plan:
def compile(self, expr: ir.Table, **kwargs: Any) -> stp.Plan:
"""Construct a Substrait plan from an ibis table expression."""
from .translate import translate

Expand Down Expand Up @@ -271,4 +271,12 @@ def _get_fields(dtype: dt.DataType) -> Iterator[tuple[str | None, dt.DataType]]:
yield None, dtype.value_type
yield None, dtype.key_type
elif isinstance(dtype, dt.Struct):
yield from reversed(list(dtype.pairs.items()))
# Ibis 3
pairs = getattr(dtype, "pairs", None)
if pairs is None:
# Ibis 4 and 5
pairs = getattr(dtype, "fields", None)

if pairs is None:
raise AttributeError
yield from reversed(list(pairs.items()))
123 changes: 73 additions & 50 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,11 @@
# Python <=3.9
from typing_extensions import TypeAlias

IBIS_4 = False
if version.parse(ibis.__version__) >= version.parse("4.0.0"):
IBIS_4 = True
IBIS_GTE_4 = version.parse(ibis.__version__) >= version.parse("4.0.0")
IBIS_GTE_5 = version.parse(ibis.__version__) >= version.parse("5.0.0")


if IBIS_4:
if IBIS_GTE_4:
import warnings

from ibis.common.graph import toposort
Expand All @@ -55,19 +54,19 @@
category=FutureWarning,
)
else:
from ibis.util import to_op_dag
from ibis.util import to_op_dag # type: ignore

# There is no ops.CountStar in Ibis 3.x but to register it for 4.x below, it
# can't be undefined here.
# We remap it to ops.Count just to avoid an attribute error, it will never
# be used to route in Ibis 3.x because it doesn't exist.
ops.CountStar = ops.Count
ops.CountStar = ops.Count # type: ignore
# ops.ValueOp renamed to ops.Value in Ibis 3.2
# We manually add ops.Value here for Ibis 3.0 compatibility
if hasattr(ops, "ValueOp"):
ops.Value = ops.ValueOp
ops.Value = ops.ValueOp # type: ignore
if hasattr(ops, "BinaryOp"):
ops.Binary = ops.BinaryOp
ops.Binary = ops.BinaryOp # type: ignore

T = TypeVar("T")

Expand Down Expand Up @@ -255,13 +254,11 @@ def _literal(
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
dtype = op.output_dtype
value = op.value
if value is None:
if op.value is None:
return stalg.Expression(
literal=stalg.Expression.Literal(null=translate(dtype, **kwargs))
literal=stalg.Expression.Literal(null=translate(op.output_dtype, **kwargs)) # type: ignore
)
return stalg.Expression(literal=translate_literal(dtype, op.value))
return stalg.Expression(literal=translate_literal(op.output_dtype, op.value)) # type: ignore


@functools.singledispatch
Expand Down Expand Up @@ -464,6 +461,20 @@ def _following_int(offset: int) -> stalg.Expression.WindowFunction.Bound:
)


if IBIS_GTE_5:

    @translate_following.register
    @translate_preceding.register
    def _window_boundary(  # type: ignore
        boundary: ops.window.WindowBoundary,
    ) -> stalg.Expression.WindowFunction.Bound:
        """Translate an Ibis 5.x window boundary into a Substrait bound.

        The same handler is registered on both ``translate_preceding`` and
        ``translate_following``. It inspects ``boundary.preceding`` to pick
        the matching translator and forwards the unwrapped
        ``boundary.value.value`` to it.
        """
        # WindowBoundary is the new window boundary class in Ibis 5.x.
        translate_bound = (
            translate_preceding if boundary.preceding else translate_following
        )
        return translate_bound(boundary.value.value)


def _translate_window_bounds(
precedes: tuple[int, int] | int | None,
follows: tuple[int, int] | int | None,
Expand Down Expand Up @@ -516,9 +527,9 @@ def alias_op(
return translate(op.arg.op(), compiler=compiler, **kwargs)


@translate.register(ops.ValueOp)
@translate.register(ops.ValueOp) # type: ignore
def value_op(
op: ops.ValueOp,
op: ops.ValueOp, # type: ignore
*,
compiler: SubstraitCompiler,
**kwargs: Any,
Expand All @@ -542,33 +553,44 @@ def value_op(
)


@translate.register(ops.WindowOp)
@translate.register(ops.WindowOp) # type: ignore
def window_op(
op: ops.WindowOp,
op: ops.WindowOp, # type: ignore
*,
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.Expression:
lower_bound, upper_bound = _translate_window_bounds(
op.window.preceding, op.window.following
)
if IBIS_GTE_5:
# Ibis >= 5.x
window_gb = op.frame.group_by
window_ob = op.frame.order_by
start = op.frame.start
end = op.frame.end
func = op.func
func_args = op.func.args
else:
window_gb = op.window._group_by
window_ob = op.window._order_by
start = op.window.preceding
end = op.window.following
func = op.expr.op()
func_args = op.expr.op().args

lower_bound, upper_bound = _translate_window_bounds(start, end)

return stalg.Expression(
window_function=stalg.Expression.WindowFunction(
function_reference=compiler.function_id(op.expr.op()),
partitions=[
translate(gb, compiler=compiler, **kwargs) for gb in op.window._group_by
],
sorts=[
translate(ob, compiler=compiler, **kwargs) for ob in op.window._order_by
],
function_reference=compiler.function_id(func),
partitions=[translate(gb, compiler=compiler, **kwargs) for gb in window_gb],
sorts=[translate(ob, compiler=compiler, **kwargs) for ob in window_ob],
output_type=translate(op.output_dtype),
phase=stalg.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT,
arguments=[
stalg.FunctionArgument(
value=translate(arg, compiler=compiler, **kwargs)
)
for arg in op.expr.op().args
if isinstance(arg, ir.Expr)
for arg in func_args
if isinstance(arg, (ir.Expr, ops.Value))
],
lower_bound=lower_bound,
upper_bound=upper_bound,
Expand All @@ -586,7 +608,7 @@ def _reduction(
return stalg.AggregateFunction(
function_reference=compiler.function_id(op),
arguments=[
stalg.FunctionArgument(value=translate(op.arg, compiler=compiler, **kwargs))
stalg.FunctionArgument(value=translate(op.arg, compiler=compiler, **kwargs)) # type: ignore
],
sorts=[], # TODO: ibis doesn't support this yet
phase=stalg.AggregationPhase.AGGREGATION_PHASE_INITIAL_TO_RESULT,
Expand All @@ -602,10 +624,10 @@ def _count(
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.AggregateFunction:
translated_args = []
translated_args: list[stalg.FunctionArgument] = []
# TODO: remove this expr
arg = op.arg.op().to_expr()
if not isinstance(arg, (ir.TableExpr, ops.PhysicalTable)):
if not isinstance(arg, (ir.TableExpr, ops.PhysicalTable)): # type: ignore
translated_args.append(
stalg.FunctionArgument(value=translate(arg, compiler=compiler, **kwargs))
)
Expand Down Expand Up @@ -732,7 +754,7 @@ def unbound_table(
)


def _get_child_relation_field_offsets(table: ir.TableExpr) -> dict[ops.TableNode, int]:
def _get_child_relation_field_offsets(table: ir.Table) -> dict[ops.TableNode, int]:
"""Return the offset of each of table's fields.
This function calculates the starting index of a relation's fields, as if
Expand Down Expand Up @@ -876,14 +898,14 @@ def selection(


def _get_selections(op: ops.Selection) -> Sequence[ir.Column]:
if IBIS_4:
if IBIS_GTE_4:
# projection / emit
selections = [
col
for sel in (x.to_expr() for x in op.selections) # map ops to exprs
for col in (
map(sel.__getitem__, sel.columns)
if isinstance(sel, ir.TableExpr)
if isinstance(sel, ir.TableExpr) # type: ignore
else [sel]
)
]
Expand All @@ -893,17 +915,17 @@ def _get_selections(op: ops.Selection) -> Sequence[ir.Column]:
col
for sel in op.selections
for col in (
sel.get_columns(sel.columns) if isinstance(sel, ir.TableExpr) else [sel]
sel.get_columns(sel.columns) if isinstance(sel, ir.TableExpr) else [sel] # type: ignore
)
]

return selections


def _find_parent_tables(op: ops.Selection) -> set[ir.Table]:
def _find_parent_tables(op: ops.Selection) -> set[ops.PhysicalTable]:
# TODO: settle on a better source table definition than "PhysicalTable with
# a schema"
if IBIS_4:
if IBIS_GTE_4:
source_tables = {
t
for t in toposort(op).keys()
Expand All @@ -912,7 +934,7 @@ def _find_parent_tables(op: ops.Selection) -> set[ir.Table]:
else:
source_tables = {
t
for t in to_op_dag(op.to_expr()).keys()
for t in to_op_dag(op.to_expr()).keys() # type: ignore
if isinstance(t, ops.PhysicalTable) and hasattr(t, "schema")
}

Expand Down Expand Up @@ -962,8 +984,9 @@ def join(
**kwargs: Any,
) -> stalg.Rel:
child_rel_field_offsets = kwargs.pop("child_rel_field_offsets", None)
expr = op.to_expr() # type: ignore
child_rel_field_offsets = (
child_rel_field_offsets or _get_child_relation_field_offsets(op.to_expr())
child_rel_field_offsets or _get_child_relation_field_offsets(expr)
)
predicates = [pred.op().to_expr() for pred in op.predicates]
return stalg.Rel(
Expand Down Expand Up @@ -1327,7 +1350,7 @@ def _not_exists_subquery(

return stalg.Expression(
scalar_function=stalg.Expression.ScalarFunction(
function_reference=compiler.function_id(ops.Not(op.to_expr())),
function_reference=compiler.function_id(ops.Not(op.to_expr())), # type: ignore
output_type=translate(op.output_dtype),
arguments=[
stalg.FunctionArgument(
Expand Down Expand Up @@ -1362,7 +1385,7 @@ def _floor_ceil_cast(
stalg.FunctionArgument(
value=translate(arg, compiler=compiler, **kwargs)
)
for arg in op.func_args
for arg in op.func_args # type: ignore
if isinstance(arg, (ir.Expr, ops.Value))
],
)
Expand Down Expand Up @@ -1449,15 +1472,15 @@ def _upcast_bin_op(op: ops.Binary) -> ops.Binary:
if left == right:
return op
elif dt.castable(left, right, upcast=True):
if IBIS_4:
return type(op)(ops.Cast(op.left, to=right), op.right)
if IBIS_GTE_4:
return type(op)(ops.Cast(op.left, to=right), op.right) # type: ignore
else:
return type(op)(op.left.cast(right), op.right)
return type(op)(op.left.cast(right), op.right) # type: ignore
elif dt.castable(right, left, upcast=True):
if IBIS_4:
return type(op)(op.left, ops.Cast(op.right, to=left))
if IBIS_GTE_4:
return type(op)(op.left, ops.Cast(op.right, to=left)) # type: ignore
else:
return type(op)(op.left, op.right.cast(left))
return type(op)(op.left, op.right.cast(left)) # type: ignore
else:
raise TypeError(
f"binop {type(op).__name__} called with incompatible types {left=} {right=}"
Expand All @@ -1477,9 +1500,9 @@ def _upcast_bin_op(op: ops.Binary) -> ops.Binary:
@_upcast.register(ops.RPad)
def _upcast_string_op(op: string_op) -> string_op:
# Substrait wants Int32 for all numeric args to string functions
if IBIS_4:
if IBIS_GTE_4:
casted_args = [
ops.Cast(newop, to=dt.Int32())
ops.Cast(newop, to=dt.Int32()) # type: ignore
if isinstance(newop.output_dtype, dt.SignedInteger)
else newop
for newop in op.args
Expand Down
6 changes: 6 additions & 0 deletions ibis_substrait/tests/compiler/test_decompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from ibis_substrait.compiler.decompile import decompile, decompile_schema
from ibis_substrait.proto.substrait.ibis import type_pb2 as stt

# Skip marker applied to tests that should not run when the installed ibis
# is 5.0.0 or newer (decompiler support is not being extended to 5.x).
ibis5 = pytest.mark.skipif(
version.parse(ibis.__version__) >= version.parse("5.0.0"),
reason="Not extending decompiler support further",
)


@pytest.fixture
def t():
Expand Down Expand Up @@ -41,6 +46,7 @@ def q():
return ibis.table([("e", "string"), ("f", "int64")], name="q")


@ibis5
@pytest.mark.parametrize(
"expr_fn",
[
Expand Down

0 comments on commit 22c56b6

Please sign in to comment.