diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index 450c2f1b..52ed38e9 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -37,6 +37,14 @@ IBIS_GTE_5 = version.parse(ibis.__version__) >= version.parse("5.0.0") IBIS_GTE_6 = version.parse(ibis.__version__) >= version.parse("6.0.0") +IBIS_GTE_7 = version.parse(ibis.__version__) >= version.parse("7.0.0") + + +# When an op gets renamed between major versions, we can assign the old name to this +# DummyOp so that we don't get Attribute errors when the new version goes looking for it +class DummyOp(ops.Value): + pass + try: from typing import TypeAlias @@ -47,6 +55,11 @@ T = TypeVar("T") +if IBIS_GTE_7: + ops.Where = ops.IfElse # type: ignore + ops.Contains = DummyOp # type: ignore + + def _nullability(dtype: dt.DataType) -> stt.Type.Nullability.V: return ( stt.Type.Nullability.NULLABILITY_NULLABLE @@ -474,9 +487,9 @@ def _window_boundary( # type: ignore ) -> stalg.Expression.WindowFunction.Bound: # new window boundary class in Ibis 5.x if boundary.preceding: - return translate_preceding(boundary.value.value) + return translate_preceding(boundary.value.value) # type: ignore else: - return translate_following(boundary.value.value) + return translate_following(boundary.value.value) # type: ignore def _translate_window_bounds( @@ -808,7 +821,7 @@ def selection( ) # filter if op.predicates: - predicates = [pred.to_expr() for pred in op.predicates] + predicates = [pred.to_expr() for pred in op.predicates] # type: ignore relation = stalg.Rel( filter=stalg.FilterRel( input=relation, @@ -902,7 +915,8 @@ def _get_selections(op: ops.Selection) -> Sequence[ir.Column]: # projection / emit selections = [ col - for sel in (x.to_expr() for x in op.selections) # map ops to exprs + # map ops to exprs + for sel in (x.to_expr() for x in op.selections) # type: ignore for col in ( map(sel.__getitem__, sel.columns) if isinstance(sel, ir.TableExpr) # type: ignore @@ -1010,9 +1024,9 @@ def translate_set_op_type(op: ops.SetOp) -> stalg.SetRel.SetOp.V: ) -@translate_set_op_type.register(ops.Union) +@translate_set_op_type.register(ops.Union) # type: ignore def set_op_type_union(op: ops.Union) -> stalg.SetRel.SetOp.V: - if op.distinct: + if op.distinct: # type: ignore return stalg.SetRel.SetOp.SET_OP_UNION_DISTINCT return stalg.SetRel.SetOp.SET_OP_UNION_ALL @@ -1055,8 +1069,8 @@ def aggregation( if op.having: raise NotImplementedError("`having` not yet implemented") - table = op.table.to_expr() - predicates = [pred.to_expr() for pred in op.predicates] + table = op.table.to_expr() # type: ignore + predicates = [pred.to_expr() for pred in op.predicates] # type: ignore input = translate( table.filter(predicates) if predicates else table, compiler=compiler, @@ -1116,9 +1130,9 @@ def _simple_searched_case( return stalg.Expression(if_then=stalg.Expression.IfThen(ifs=_ifs, **_else)) -@translate.register(ops.Where) +@translate.register(ops.Where) # type: ignore def _where( - op: ops.Where, + op: ops.Where, # type: ignore *, compiler: SubstraitCompiler, **kwargs: Any, @@ -1136,9 +1150,9 @@ def _where( return stalg.Expression(if_then=stalg.Expression.IfThen(ifs=_ifs, **_else)) -@translate.register(ops.Contains) +@translate.register(ops.Contains) # type: ignore def _contains( - op: ops.Contains, + op: ops.Contains, # type: ignore *, compiler: SubstraitCompiler, **kwargs: Any, @@ -1158,6 +1172,12 @@ def _contains( ) +if IBIS_GTE_7: + # Contains was decomposed into two separate ops in Ibis 7.x + translate.register(ops.InColumn)(_contains) + translate.register(ops.InValues)(_contains) + + @translate.register(ops.Cast) def _cast( op: ops.Cast, @@ -1229,7 +1249,7 @@ def _floordivide( **kwargs: Any, ) -> stalg.Expression: left, right = op.left, op.right - return translate((left / right).floor(), compiler=compiler, **kwargs) + return translate((left / right).floor(), compiler=compiler, **kwargs) # type: ignore @translate.register(ops.Clip) @@ -1243,19 +1263,19 @@ def _clip( if lower is not None and upper is not None: return translate( - (arg >= lower).ifelse((arg <= upper).ifelse(arg, upper), lower), + (arg >= lower).ifelse((arg <= upper).ifelse(arg, upper), lower), # type: ignore compiler=compiler, **kwargs, ) elif lower is not None: return translate( - (arg >= lower).ifelse(arg, lower), + (arg >= lower).ifelse(arg, lower), # type: ignore compiler=compiler, **kwargs, ) elif upper is not None: return translate( - (arg <= upper).ifelse(arg, upper), + (arg <= upper).ifelse(arg, upper), # type: ignore compiler=compiler, **kwargs, ) diff --git a/pyproject.toml b/pyproject.toml index 1bde7889..042f72c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ substrait = ">=0.2.1" [tool.poetry.group.dev.dependencies] black = ">=23.0.0" -duckdb = ">=0.4.0" +duckdb = ">=0.8.1" duckdb-engine = ">=0.5" ipython = ">=8.2.0" ruff = ">=0.0.252" @@ -137,6 +137,9 @@ filterwarnings = [ "ignore: Deprecated API features detected:sqlalchemy.exc.RemovedIn20Warning", # ignore struct pairs deprecation while we still support 4.0 "ignore: `Struct.pairs` is deprecated:FutureWarning", + # ignore output_dtype deprecation until version 7.0 is minimum supported + "ignore:`Value.output_dtype` is deprecated:FutureWarning", + ] markers = ["no_decompile"] norecursedirs = ["site-packages", "dist-packages", ".direnv"]