Skip to content

Commit

Permalink
feat: add support for Ibis 6.x
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed Aug 2, 2023
1 parent 5c6f4fa commit fcfc595
Show file tree
Hide file tree
Showing 5 changed files with 432 additions and 43 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ jobs:
ibis-version:
- "4.1"
- "5.1"
- "6.0"
include:
- os: windows-latest
python-version: "3.11"
ibis-version: "5.1"

ibis-version: "6.0"

steps:
- uses: actions/checkout@v3
Expand Down Expand Up @@ -133,6 +133,7 @@ jobs:
ibis-version:
- "4.1"
- "5.1"
- "6.0"

steps:
- uses: actions/checkout@v3
Expand Down
54 changes: 41 additions & 13 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,49 @@ def _timestamp(dtype: dt.Timestamp) -> stt.Type:
return stt.Type(timestamp=stt.Type.Timestamp(nullability=nullability))


@translate.register
def _interval(dtype: dt.Interval) -> stt.Type:
unit = dtype.unit
nullability = _nullability(dtype)
if IBIS_GTE_6:

@translate.register
def _interval(dtype: dt.Interval) -> stt.Type:
unit = dtype.unit.name
nullability = _nullability(dtype)

if unit == "YEAR":
return stt.Type(
interval_year=stt.Type.IntervalYear(nullability=nullability)
)
elif unit == "MONTH":
return stt.Type(
interval_year=stt.Type.IntervalYear(nullability=nullability)
)
elif unit == "DAY":
return stt.Type(interval_day=stt.Type.IntervalDay(nullability=nullability))
elif unit == "SECOND":
return stt.Type(interval_day=stt.Type.IntervalDay(nullability=nullability))

if unit == "Y":
return stt.Type(interval_year=stt.Type.IntervalYear(nullability=nullability))
elif unit == "M":
return stt.Type(interval_year=stt.Type.IntervalYear(nullability=nullability))
elif unit == "D":
return stt.Type(interval_day=stt.Type.IntervalDay(nullability=nullability))
elif unit == "s":
return stt.Type(interval_day=stt.Type.IntervalDay(nullability=nullability))
raise ValueError(f"unsupported substrait unit: {unit!r}")

else:

@translate.register
def _interval(dtype: dt.Interval) -> stt.Type:
unit = dtype.unit
nullability = _nullability(dtype)

if unit == "Y":
return stt.Type(
interval_year=stt.Type.IntervalYear(nullability=nullability)
)
elif unit == "M":
return stt.Type(
interval_year=stt.Type.IntervalYear(nullability=nullability)
)
elif unit == "D":
return stt.Type(interval_day=stt.Type.IntervalDay(nullability=nullability))
elif unit == "s":
return stt.Type(interval_day=stt.Type.IntervalDay(nullability=nullability))

raise ValueError(f"unsupported substrait unit: {unit!r}")
raise ValueError(f"unsupported substrait unit: {unit!r}")


@translate.register
Expand Down
5 changes: 4 additions & 1 deletion ibis_substrait/tests/compiler/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def test_extension_boolean(compiler, left, right, bin_op, exp_func, exp_uri):


def test_extension_udf_compile(compiler):
from ibis.udf.vectorized import elementwise
try:
from ibis.udf.vectorized import elementwise
except ImportError:
from ibis.legacy.udf.vectorized import elementwise

pc = None

Expand Down
15 changes: 13 additions & 2 deletions ibis_substrait/tests/integration/test_pyarrow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ibis
import pytest
from ibis.udf.vectorized import elementwise
from packaging.version import parse as vparse

from ibis_substrait.compiler.core import SubstraitCompiler
Expand All @@ -9,6 +8,11 @@
pc = pytest.importorskip("pyarrow.compute")
pa_substrait = pytest.importorskip("pyarrow.substrait")

try:
from ibis.udf.vectorized import elementwise
except ImportError:
from ibis.legacy.udf.vectorized import elementwise


arrow12 = pytest.mark.skipif(
vparse(pa.__version__) <= vparse("11.0.0"),
Expand Down Expand Up @@ -103,7 +107,14 @@ def register_pyarrow_udf(udf, registry=None):
"""
import inspect

from ibis.backends.pyarrow.datatypes import to_pyarrow_type
try:
# Ibis 6.x
from ibis.formats.pyarrow import PyArrowType

to_pyarrow_type = PyArrowType.from_ibis
except ImportError:
# Ibis 4.x, 5.x
from ibis.backends.pyarrow.datatypes import to_pyarrow_type

if registry is None:
registry = pc.function_registry()
Expand Down
Loading

0 comments on commit fcfc595

Please sign in to comment.