Skip to content

Commit

Permalink
Refactor FunctionsCatalog and improve functions lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
amol- committed May 8, 2024
1 parent 145cb4a commit a727475
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 96 deletions.
29 changes: 11 additions & 18 deletions src/substrait/sql/extended_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,30 +194,20 @@ def _parse_function_invokation(
invokation expression itself.
"""
arguments = [argument_parsed_expr] + list(additional_arguments)
signature = self._functions_catalog.signature(
signature = self._functions_catalog.make_signature(
function_name, proto_argtypes=[arg.type for arg in arguments]
)

try:
function_anchor = self._functions_catalog.function_anchor(signature)
except KeyError:
# No function found with the exact types, try any1_any1 version
# TODO: What about cases like i32_any1? What about any instead of any1?
# TODO: What about optional arguments? IE: "i32_i32?"
signature = f"{function_name}:{'_'.join(['any1']*len(arguments))}"
function_anchor = self._functions_catalog.function_anchor(signature)

function_return_type = self._functions_catalog.function_return_type(signature)
if function_return_type is None:
print("No return type for", signature)
# TODO: Is this the right way to handle this?
function_return_type = left_type
registered_function = self._functions_catalog.lookup_function(signature)
if registered_function is None:
raise KeyError(f"Function not found: {signature}")

return (
signature,
function_return_type,
registered_function.signature,
registered_function.return_type,
proto.Expression(
scalar_function=proto.Expression.ScalarFunction(
function_reference=function_anchor,
function_reference=registered_function.function_anchor,
arguments=[
proto.FunctionArgument(value=arg.expression)
for arg in arguments
Expand Down Expand Up @@ -255,3 +245,6 @@ def duplicate(
expression or self.expression,
invoked_functions or self.invoked_functions,
)

def __repr__(self):
return f"<ParsedSubstraitExpression {self.output_name} {self.type}>"
257 changes: 180 additions & 77 deletions src/substrait/sql/functions_catalog.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,80 @@
import os
import pathlib
from collections.abc import Iterable

import yaml

from substrait import proto
from substrait.gen.proto.type_pb2 import Type as SubstraitType
from substrait.gen.proto.extensions.extensions_pb2 import (
SimpleExtensionURI,
SimpleExtensionDeclaration,
)


class RegisteredSubstraitFunction:
"""A Substrait function loaded from an extension file.
The FunctionsCatalog will keep a collection of RegisteredSubstraitFunction
and will use them to generate the necessary extension URIs and extensions.
"""

def __init__(self, signature: str, function_anchor: int | None, impl: dict):
self.signature = signature
self.function_anchor = function_anchor
self.variadic = impl.get("variadic", False)

if "return" in impl:
self.return_type = self._type_from_name(impl["return"])
else:
# We do always need a return type
# to know which type to propagate up to the invoker
_, argtypes = FunctionsCatalog.parse_signature(signature)
# TODO: Is this the right way to handle this?
self.return_type = self._type_from_name(argtypes[0])

@property
def name(self) -> str:
name, _ = FunctionsCatalog.parse_signature(self.signature)
return name

@property
def arguments(self) -> list[str]:
_, argtypes = FunctionsCatalog.parse_signature(self.signature)
return argtypes

@property
def arguments_type(self) -> list[SubstraitType | None]:
return [self._type_from_name(arg) for arg in self.arguments]

def _type_from_name(self, typename: str) -> SubstraitType | None:
nullable = False
if typename.endswith("?"):
nullable = True

typename = typename.strip("?")
if typename in ("any", "any1"):
return None

if typename == "boolean":
# For some reason boolean is an exception to the naming convention
typename = "bool"

try:
type_descriptor = SubstraitType.DESCRIPTOR.fields_by_name[
typename
].message_type
except KeyError:
# TODO: improve resolution of complext type like LIST?<any>
print("Unsupported type", typename)
return None

type_class = getattr(SubstraitType, type_descriptor.name)
nullability = (
SubstraitType.Nullability.NULLABILITY_REQUIRED
if not nullable
else SubstraitType.Nullability.NULLABILITY_NULLABLE
)
return SubstraitType(**{typename: type_class(nullability=nullability)})


class FunctionsCatalog:
Expand Down Expand Up @@ -32,20 +104,21 @@ class FunctionsCatalog:
)

def __init__(self):
self._registered_extensions = {}
self._substrait_extension_uris = {}
self._substrait_extension_functions = {}
self._functions = {}
self._functions_return_type = {}

def load_standard_extensions(self, dirpath):
def load_standard_extensions(self, dirpath: str | os.PathLike):
"""Load all standard substrait extensions from the target directory."""
for ext in self.STANDARD_EXTENSIONS:
self.load(dirpath, ext)

def load(self, dirpath, filename):
def load(self, dirpath: str | os.PathLike, filename: str):
"""Load an extension from a YAML file in a target directory."""
with open(pathlib.Path(dirpath) / filename.strip("/")) as f:
sections = yaml.safe_load(f)

loaded_functions = set()
functions_return_type = {}
loaded_functions = {}
for functions in sections.values():
for function in functions:
function_name = function["name"]
Expand All @@ -56,100 +129,80 @@ def load(self, dirpath, filename):
t.get("value", "unknown").strip("?")
for t in impl.get("args", [])
]
if impl.get("variadic", False):
# TODO: Variadic functions.
argtypes *= 2

if not argtypes:
signature = function_name
else:
signature = f"{function_name}:{'_'.join(argtypes)}"
loaded_functions.add(signature)
print("Loaded function", signature)
functions_return_type[signature] = self._type_from_name(
impl["return"]
loaded_functions[signature] = RegisteredSubstraitFunction(
signature, None, impl
)

self._register_extensions(filename, loaded_functions, functions_return_type)
self._register_extensions(filename, loaded_functions)

def _register_extensions(
self, extension_uri, loaded_functions, functions_return_type
self,
extension_uri: str,
loaded_functions: dict[str, RegisteredSubstraitFunction],
):
if extension_uri not in self._registered_extensions:
ext_anchor_id = len(self._registered_extensions) + 1
self._registered_extensions[extension_uri] = proto.SimpleExtensionURI(
if extension_uri not in self._substrait_extension_uris:
ext_anchor_id = len(self._substrait_extension_uris) + 1
self._substrait_extension_uris[extension_uri] = SimpleExtensionURI(
extension_uri_anchor=ext_anchor_id, uri=extension_uri
)

for function in loaded_functions:
if function in self._functions:
for signature, registered_function in loaded_functions.items():
if signature in self._substrait_extension_functions:
extensions_by_anchor = self.extension_uris_by_anchor
existing_function = self._functions[function]
existing_function = self._substrait_extension_functions[signature]
function_extension = extensions_by_anchor[
existing_function.extension_uri_reference
].uri
raise ValueError(
f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}"
)
extension_anchor = self._registered_extensions[
extension_anchor = self._substrait_extension_uris[
extension_uri
].extension_uri_anchor
function_anchor = len(self._functions) + 1
self._functions[function] = (
proto.SimpleExtensionDeclaration.ExtensionFunction(
function_anchor = len(self._substrait_extension_functions) + 1
self._substrait_extension_functions[signature] = (
SimpleExtensionDeclaration.ExtensionFunction(
extension_uri_reference=extension_anchor,
name=function,
name=signature,
function_anchor=function_anchor,
)
)
self._functions_return_type[function] = functions_return_type[function]

def _type_from_name(self, typename):
nullable = False
if typename.endswith("?"):
nullable = True

typename = typename.strip("?")
if typename in ("any", "any1"):
return None

if typename == "boolean":
# For some reason boolean is an exception to the naming convention
typename = "bool"

try:
type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[
typename
].message_type
except KeyError:
# TODO: improve resolution of complext type like LIST?<any>
print("Unsupported type", typename)
return None

type_class = getattr(proto.Type, type_descriptor.name)
nullability = (
proto.Type.Nullability.NULLABILITY_REQUIRED
if not nullable
else proto.Type.Nullability.NULLABILITY_NULLABLE
)
return proto.Type(**{typename: type_class(nullability=nullability)})
registered_function.function_anchor = function_anchor
self._functions.setdefault(registered_function.name, []).append(
registered_function
)

@property
def extension_uris_by_anchor(self):
def extension_uris_by_anchor(self) -> dict[int, SimpleExtensionURI]:
return {
ext.extension_uri_anchor: ext
for ext in self._registered_extensions.values()
for ext in self._substrait_extension_uris.values()
}

@property
def extension_uris(self):
return list(self._registered_extensions.values())
def extension_uris(self) -> list[SimpleExtensionURI]:
return list(self._substrait_extension_uris.values())

@property
def extensions(self):
return list(self._functions.values())
def extensions_functions(
self,
) -> list[SimpleExtensionDeclaration.ExtensionFunction]:
return list(self._substrait_extension_functions.values())

@classmethod
def make_signature(
cls, function_name: str, proto_argtypes: Iterable[SubstraitType]
):
"""Create a function signature from a function name and substrait types.
The signature is generated according to Function Signature Compound Names
as described in the Substrait documentation.
"""

def signature(self, function_name, proto_argtypes):
def _normalize_arg_types(argtypes):
for argtype in argtypes:
kind = argtype.WhichOneof("kind")
Expand All @@ -160,23 +213,73 @@ def _normalize_arg_types(argtypes):

return f"{function_name}:{'_'.join(_normalize_arg_types(proto_argtypes))}"

def function_anchor(self, function):
return self._functions[function].function_anchor
@classmethod
def parse_signature(cls, signature: str) -> tuple[str, list[str]]:
"""Parse a function signature and returns name and type names"""
try:
function_name, signature_args = signature.split(":")
except ValueError:
function_name = signature
argtypes = []
else:
argtypes = signature_args.split("_")
return function_name, argtypes

def function_return_type(self, function):
return self._functions_return_type[function]
def extensions_for_functions(
self, function_signatures: Iterable[str]
) -> tuple[list[SimpleExtensionURI], list[SimpleExtensionDeclaration]]:
"""Given a set of function signatures, return the necessary extensions.
def extensions_for_functions(self, functions):
The function will return the URIs of the extensions and the extension
that have to be declared in the plan to use the functions.
"""
uris_anchors = set()
extensions = []
for f in functions:
ext = self._functions[f]
if not ext.extension_uri_reference:
# Built-in function
continue
for f in function_signatures:
ext = self._substrait_extension_functions[f]
uris_anchors.add(ext.extension_uri_reference)
extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext))
extensions.append(SimpleExtensionDeclaration(extension_function=ext))

uris_by_anchor = self.extension_uris_by_anchor
extension_uris = [uris_by_anchor[uri_anchor] for uri_anchor in uris_anchors]
return extension_uris, extensions

def lookup_function(self, signature: str) -> RegisteredSubstraitFunction | None:
"""Given the signature of a function invocation, return the matching function."""
function_name, invocation_argtypes = self.parse_signature(signature)

functions = self._functions.get(function_name)
if not functions:
# No function with such a name at all.
return None

is_variadic = functions[0].variadic
if is_variadic:
# If it's variadic we care about only the first parameter.
invocation_argtypes = invocation_argtypes[:1]

found_function = None
for function in functions:
accepted_function_arguments = function.arguments
for argidx, argtype in enumerate(invocation_argtypes):
try:
accepted_argument = accepted_function_arguments[argidx]
except IndexError:
# More arguments than available were provided
break
if accepted_argument != argtype and accepted_argument not in (
"any",
"any1",
):
break
else:
if argidx < len(accepted_function_arguments) - 1:
# Not enough arguments were provided
remainder = accepted_function_arguments[argidx + 1 :]
if all(arg.endswith("?") for arg in remainder):
# All remaining arguments are optional
found_function = function
else:
found_function = function

return found_function
4 changes: 3 additions & 1 deletion src/substrait/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __getitem__(self, argument):
if isinstance(argument, dispatch_cls):
return func
else:
raise ValueError(f"Unsupported SQL Node type: {cls}")
raise ValueError(
f"Unsupported SQL Node type: {argument.__class__.__name__} -> {argument}"
)

def __call__(self, obj, dispatch_argument, *args, **kwargs):
return self[dispatch_argument](obj, dispatch_argument, *args, **kwargs)

0 comments on commit a727475

Please sign in to comment.