Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

example: self-correcting loop for RAG #6420

Merged
merged 25 commits into from
Dec 20, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move functions
ZanSara committed Dec 18, 2023
commit bcb6a8460f259c20084d0b2dc92a97b5ec60ce6d
119 changes: 3 additions & 116 deletions haystack/components/routers/conditional_router.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import importlib
import inspect
import logging
import sys
from typing import List, Dict, Any, Set, get_origin
from typing import List, Dict, Any, Set

from jinja2 import meta, Environment, TemplateSyntaxError
from jinja2.nativetypes import NativeEnvironment

from haystack import component, default_from_dict, default_to_dict, DeserializationError
from haystack import component, default_from_dict, default_to_dict
from haystack.utils.type_serialization import serialize_type, deserialize_type

logger = logging.getLogger(__name__)

@@ -20,117 +18,6 @@ class RouteConditionException(Exception):
"""Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter."""


def serialize_type(target: Any) -> str:
"""
Serializes a type or an instance to its string representation, including the module name.
This function handles types, instances of types, and special typing objects.
It assumes that non-typing objects will have a '__name__' attribute and raises
an error if a type cannot be serialized.
:param target: The object to serialize, can be an instance or a type.
:type target: Any
:return: The string representation of the type.
:raises ValueError: If the type cannot be serialized.
"""
# If the target is a string and contains a dot, treat it as an already serialized type
if isinstance(target, str) and "." in target:
return target

# Determine if the target is a type or an instance of a typing object
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
type_obj = target if is_type_or_typing else type(target)
module = inspect.getmodule(type_obj)
type_obj_repr = repr(type_obj)

if type_obj_repr.startswith("typing."):
# e.g., typing.List[int] -> List[int], we'll add the module below
type_name = type_obj_repr.split(".", 1)[1]
elif hasattr(type_obj, "__name__"):
type_name = type_obj.__name__
else:
# If type cannot be serialized, raise an error
raise ValueError(f"Could not serialize type: {type_obj_repr}")

# Construct the full path with module name if available
if module and hasattr(module, "__name__"):
if module.__name__ == "builtins":
# omit the module name for builtins, it just clutters the output
# e.g. instead of 'builtins.str', we'll just return 'str'
full_path = type_name
else:
full_path = f"{module.__name__}.{type_name}"
else:
full_path = type_name

return full_path


def deserialize_type(type_str: str) -> Any:
"""
Deserializes a type given its full import path as a string, including nested generic types.
This function will dynamically import the module if it's not already imported
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
:param type_str: The string representation of the type's full import path.
:return: The deserialized type object.
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
"""

def parse_generic_args(args_str):
args = []
bracket_count = 0
current_arg = ""

for char in args_str:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1

if char == "," and bracket_count == 0:
args.append(current_arg.strip())
current_arg = ""
else:
current_arg += char

if current_arg:
args.append(current_arg.strip())

return args

if "[" in type_str and type_str.endswith("]"):
# Handle generics
main_type_str, generics_str = type_str.split("[", 1)
generics_str = generics_str[:-1]

main_type = deserialize_type(main_type_str)
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))

# Reconstruct
return main_type[generic_args]

else:
# Handle non-generics
parts = type_str.split(".")
module_name = ".".join(parts[:-1]) or "builtins"
type_name = parts[-1]

module = sys.modules.get(module_name)
if not module:
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise DeserializationError(f"Could not import the module: {module_name}") from e

deserialized_type = getattr(module, type_name, None)
if not deserialized_type:
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

return deserialized_type


@component
class ConditionalRouter:
"""
117 changes: 117 additions & 0 deletions haystack/utils/type_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import importlib
import inspect
import sys
from typing import Any, get_origin

from haystack import DeserializationError


def serialize_type(target: Any) -> str:
"""
Serializes a type or an instance to its string representation, including the module name.
This function handles types, instances of types, and special typing objects.
It assumes that non-typing objects will have a '__name__' attribute and raises
an error if a type cannot be serialized.
:param target: The object to serialize, can be an instance or a type.
:type target: Any
:return: The string representation of the type.
:raises ValueError: If the type cannot be serialized.
"""
# If the target is a string and contains a dot, treat it as an already serialized type
if isinstance(target, str) and "." in target:
return target

# Determine if the target is a type or an instance of a typing object
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
type_obj = target if is_type_or_typing else type(target)
module = inspect.getmodule(type_obj)
type_obj_repr = repr(type_obj)

if type_obj_repr.startswith("typing."):
# e.g., typing.List[int] -> List[int], we'll add the module below
type_name = type_obj_repr.split(".", 1)[1]
elif hasattr(type_obj, "__name__"):
type_name = type_obj.__name__
else:
# If type cannot be serialized, raise an error
raise ValueError(f"Could not serialize type: {type_obj_repr}")

# Construct the full path with module name if available
if module and hasattr(module, "__name__"):
if module.__name__ == "builtins":
# omit the module name for builtins, it just clutters the output
# e.g. instead of 'builtins.str', we'll just return 'str'
full_path = type_name
else:
full_path = f"{module.__name__}.{type_name}"
else:
full_path = type_name

return full_path


def deserialize_type(type_str: str) -> Any:
"""
Deserializes a type given its full import path as a string, including nested generic types.
This function will dynamically import the module if it's not already imported
and then retrieve the type object from it. It also handles nested generic types like 'typing.List[typing.Dict[int, str]]'.
:param type_str: The string representation of the type's full import path.
:return: The deserialized type object.
:raises DeserializationError: If the type cannot be deserialized due to missing module or type.
"""

def parse_generic_args(args_str):
args = []
bracket_count = 0
current_arg = ""

for char in args_str:
if char == "[":
bracket_count += 1
elif char == "]":
bracket_count -= 1

if char == "," and bracket_count == 0:
args.append(current_arg.strip())
current_arg = ""
else:
current_arg += char

if current_arg:
args.append(current_arg.strip())

return args

if "[" in type_str and type_str.endswith("]"):
# Handle generics
main_type_str, generics_str = type_str.split("[", 1)
generics_str = generics_str[:-1]

main_type = deserialize_type(main_type_str)
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))

# Reconstruct
return main_type[generic_args]

else:
# Handle non-generics
parts = type_str.split(".")
module_name = ".".join(parts[:-1]) or "builtins"
type_name = parts[-1]

module = sys.modules.get(module_name)
if not module:
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise DeserializationError(f"Could not import the module: {module_name}") from e

deserialized_type = getattr(module, type_name, None)
if not deserialized_type:
raise DeserializationError(f"Could not locate the type: {type_name} in the module: {module_name}")

return deserialized_type