Skip to content

Commit

Permalink
chore: support ibis 7.x
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed Oct 5, 2023
1 parent 47028ae commit c6d77be
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
52 changes: 36 additions & 16 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@

IBIS_GTE_5 = version.parse(ibis.__version__) >= version.parse("5.0.0")
IBIS_GTE_6 = version.parse(ibis.__version__) >= version.parse("6.0.0")
IBIS_GTE_7 = version.parse(ibis.__version__) >= version.parse("7.0.0")


# When an op gets renamed between major versions, we can assign the old name to this
# DummyOp so that we don't get Attribute errors when the new version goes looking for it
class DummyOp(ops.Value):
pass


try:
from typing import TypeAlias
Expand All @@ -47,6 +55,11 @@
T = TypeVar("T")


if IBIS_GTE_7:
ops.Where = ops.IfElse # type: ignore
ops.Contains = DummyOp # type: ignore


def _nullability(dtype: dt.DataType) -> stt.Type.Nullability.V:
return (
stt.Type.Nullability.NULLABILITY_NULLABLE
Expand Down Expand Up @@ -474,9 +487,9 @@ def _window_boundary( # type: ignore
) -> stalg.Expression.WindowFunction.Bound:
# new window boundary class in Ibis 5.x
if boundary.preceding:
return translate_preceding(boundary.value.value)
return translate_preceding(boundary.value.value) # type: ignore
else:
return translate_following(boundary.value.value)
return translate_following(boundary.value.value) # type: ignore


def _translate_window_bounds(
Expand Down Expand Up @@ -808,7 +821,7 @@ def selection(
)
# filter
if op.predicates:
predicates = [pred.to_expr() for pred in op.predicates]
predicates = [pred.to_expr() for pred in op.predicates] # type: ignore
relation = stalg.Rel(
filter=stalg.FilterRel(
input=relation,
Expand Down Expand Up @@ -902,7 +915,8 @@ def _get_selections(op: ops.Selection) -> Sequence[ir.Column]:
# projection / emit
selections = [
col
for sel in (x.to_expr() for x in op.selections) # map ops to exprs
# map ops to exprs
for sel in (x.to_expr() for x in op.selections) # type: ignore
for col in (
map(sel.__getitem__, sel.columns)
if isinstance(sel, ir.TableExpr) # type: ignore
Expand Down Expand Up @@ -1010,9 +1024,9 @@ def translate_set_op_type(op: ops.SetOp) -> stalg.SetRel.SetOp.V:
)


@translate_set_op_type.register(ops.Union)
@translate_set_op_type.register(ops.Union) # type: ignore
def set_op_type_union(op: ops.Union) -> stalg.SetRel.SetOp.V:
if op.distinct:
if op.distinct: # type: ignore
return stalg.SetRel.SetOp.SET_OP_UNION_DISTINCT
return stalg.SetRel.SetOp.SET_OP_UNION_ALL

Expand Down Expand Up @@ -1055,8 +1069,8 @@ def aggregation(
if op.having:
raise NotImplementedError("`having` not yet implemented")

table = op.table.to_expr()
predicates = [pred.to_expr() for pred in op.predicates]
table = op.table.to_expr() # type: ignore
predicates = [pred.to_expr() for pred in op.predicates] # type: ignore
input = translate(
table.filter(predicates) if predicates else table,
compiler=compiler,
Expand Down Expand Up @@ -1116,9 +1130,9 @@ def _simple_searched_case(
return stalg.Expression(if_then=stalg.Expression.IfThen(ifs=_ifs, **_else))


@translate.register(ops.Where)
@translate.register(ops.Where) # type: ignore
def _where(
op: ops.Where,
op: ops.Where, # type: ignore
*,
compiler: SubstraitCompiler,
**kwargs: Any,
Expand All @@ -1136,9 +1150,9 @@ def _where(
return stalg.Expression(if_then=stalg.Expression.IfThen(ifs=_ifs, **_else))


@translate.register(ops.Contains)
@translate.register(ops.Contains) # type: ignore
def _contains(
op: ops.Contains,
op: ops.Contains, # type: ignore
*,
compiler: SubstraitCompiler,
**kwargs: Any,
Expand All @@ -1158,6 +1172,12 @@ def _contains(
)


if IBIS_GTE_7:
# Contains was decomposed into two separate ops in Ibis 7.x
translate.register(ops.InColumn)(_contains)
translate.register(ops.InValues)(_contains)


@translate.register(ops.Cast)
def _cast(
op: ops.Cast,
Expand Down Expand Up @@ -1229,7 +1249,7 @@ def _floordivide(
**kwargs: Any,
) -> stalg.Expression:
left, right = op.left, op.right
return translate((left / right).floor(), compiler=compiler, **kwargs)
return translate((left / right).floor(), compiler=compiler, **kwargs) # type: ignore


@translate.register(ops.Clip)
Expand All @@ -1243,19 +1263,19 @@ def _clip(

if lower is not None and upper is not None:
return translate(
(arg >= lower).ifelse((arg <= upper).ifelse(arg, upper), lower),
(arg >= lower).ifelse((arg <= upper).ifelse(arg, upper), lower), # type: ignore
compiler=compiler,
**kwargs,
)
elif lower is not None:
return translate(
(arg >= lower).ifelse(arg, lower),
(arg >= lower).ifelse(arg, lower), # type: ignore
compiler=compiler,
**kwargs,
)
elif upper is not None:
return translate(
(arg <= upper).ifelse(arg, upper),
(arg <= upper).ifelse(arg, upper), # type: ignore
compiler=compiler,
**kwargs,
)
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ substrait = ">=0.2.1"

[tool.poetry.group.dev.dependencies]
black = ">=23.0.0"
duckdb = ">=0.4.0"
duckdb = ">=0.8.1"
duckdb-engine = ">=0.5"
ipython = ">=8.2.0"
ruff = ">=0.0.252"
Expand Down Expand Up @@ -137,6 +137,9 @@ filterwarnings = [
"ignore: Deprecated API features detected:sqlalchemy.exc.RemovedIn20Warning",
# ignore struct pairs deprecation while we still support 4.0
"ignore: `Struct.pairs` is deprecated:FutureWarning",
# ignore output_dtype deprecation until version 7.0 is minimum supported
"ignore:`Value.output_dtype` is deprecated:FutureWarning",

]
markers = ["no_decompile"]
norecursedirs = ["site-packages", "dist-packages", ".direnv"]
Expand Down

0 comments on commit c6d77be

Please sign in to comment.