Skip to content

Commit

Permalink
fix(uri): allow overriding extension uri when registering extension file
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth committed May 17, 2023
1 parent 01b96c9 commit 59d5b29
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
21 changes: 18 additions & 3 deletions ibis_substrait/compiler/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,23 @@ def _parse_func(entry: Mapping[str, Any]) -> Iterator[FunctionEntry]:
yield sf


def register_extension_yaml(fname: str | Path, prefix: str | None = None) -> None:
"""Add a substrait extension YAML file to the ibis substrait compiler."""
def register_extension_yaml(
fname: str | Path, prefix: str | None = None, uri: str | None = None
) -> None:
"""Add a substrait extension YAML file to the ibis substrait compiler.
Parameters
----------
fname
The filename of the extension yaml to register.
prefix
Custom prefix to use when constructing Substrait extension URI
uri
A custom URI to use for all functions defined within `fname`.
If passed, this value overrides `prefix`.
"""
fname = Path(fname)
with open(fname) as f: # type: ignore
extension_definitions = yaml.safe_load(f)
Expand All @@ -195,7 +210,7 @@ def register_extension_yaml(fname: str | Path, prefix: str | None = None) -> Non
for named_functions in extension_definitions.values():
for function in named_functions:
for func in _parse_func(function):
func.uri = f"{prefix}/{fname.name}"
func.uri = uri or f"{prefix}/{fname.name}"
_extension_mapping[function["name"]][tuple(func.inputs)] = func


Expand Down
30 changes: 30 additions & 0 deletions ibis_substrait/tests/compiler/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,33 @@ def sub1(col, ctx=None):
assert len(plan.extensions) == 2
assert plan.extensions[0].extension_function.name == "add1"
assert plan.extensions[1].extension_function.name == "sub1"


def test_extension_register_uri_override(tmp_path):
from ibis_substrait.compiler.mapping import (
_extension_mapping,
register_extension_yaml,
)

sample_yaml = """scalar_functions:
-
name: "anotheradd"
impls:
- args:
- name: x
value: a
- name: y
value: b
return: c"""

yaml_file = tmp_path / "foo.yaml"
yaml_file.write_text(sample_yaml)

register_extension_yaml(yaml_file, uri="orkbork")

assert _extension_mapping["anotheradd"]
assert _extension_mapping["anotheradd"][("a", "b")].uri == "orkbork"

register_extension_yaml(yaml_file, prefix="orkbork")
assert _extension_mapping["anotheradd"]
assert _extension_mapping["anotheradd"][("a", "b")].uri == "orkbork/foo.yaml"

0 comments on commit 59d5b29

Please sign in to comment.