Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom strong type connection in script tool (Simplified) #733

Merged
merged 4 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def collect_tool_methods_in_module(m):
return tools


def _parse_tool_from_function(f):
def _parse_tool_from_function(f, gen_custom_type_conn=False):
if hasattr(f, "__tool") and isinstance(f.__tool, Tool):
return f.__tool
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(f)
inputs, _, _ = function_to_interface(f, gen_custom_type_conn=gen_custom_type_conn)
except Exception as e:
error_type_and_message = f"({e.__class__.__name__}) {e}"
raise BadFunctionInterface(
Expand Down
7 changes: 4 additions & 3 deletions src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import inspect
import logging
import traceback
import types
from functools import partial
from pathlib import Path
from typing import Callable, List, Mapping, Optional, Tuple, Union
Expand Down Expand Up @@ -347,7 +348,7 @@ def load_tool_for_node(self, node: Node) -> Tool:
if node.source.type == ToolSourceType.Package:
return self.load_tool_for_package_node(node)
elif node.source.type == ToolSourceType.Code:
_, tool = self.load_tool_for_script_node(node)
_, _, tool = self.load_tool_for_script_node(node)
return tool
raise NotImplementedError(f"Tool source type {node.source.type} for python tool is not supported yet.")
elif node.type == ToolType.CUSTOM_LLM:
Expand All @@ -366,15 +367,15 @@ def load_tool_for_package_node(self, node: Node) -> Tool:
target=ErrorTarget.EXECUTOR,
)

def load_tool_for_script_node(self, node: Node) -> Tuple[Callable, Tool]:
def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]:
if node.source.path is None:
raise UserErrorException(f"Node {node.name} does not have source path defined.")
path = node.source.path
m = load_python_module_from_file(self._working_dir / path)
if m is None:
raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
f = collect_tool_function_in_module(m)
return f, _parse_tool_from_function(f)
return m, f, _parse_tool_from_function(f, gen_custom_type_conn=True)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
Expand Down
29 changes: 22 additions & 7 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def _convert_to_custom(self):
return custom_connection

@classmethod
def _get_custom_keys(cls, data):
def _get_custom_keys(cls, data: Dict):
# The data could be either from yaml or from DB.
# If from yaml, 'custom_type' and 'module' are outside the configs of data.
# If from DB, 'custom_type' and 'module' are within the configs of data.
Expand Down Expand Up @@ -933,13 +933,28 @@ def _is_custom_strong_type(self):
and self.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY]
)

def _convert_to_custom_strong_type(self):
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
import importlib
def _convert_to_custom_strong_type(self, to_class) -> CustomStrongTypeConnection:
# There are two scenarios for converting a custom connection to a custom strong type connection:
# 1. The connection is created from a custom strong type connection template file.
# The custom type and module name are present in the configs.
# 2. The connection is created through the SDK PFClient or a custom connection template file.
# The custom type and module name are not present in the configs, so the target class must be
# passed in as `to_class` for the conversion.
16oeahr marked this conversation as resolved.
Show resolved Hide resolved
if to_class and not isinstance(to_class, type):
raise TypeError(f"The converted type {to_class} must be a class type.")

module = importlib.import_module(module_name)
custom_defined_connection_class = getattr(module, custom_type_class_name)
if not to_class:
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
import importlib

module = importlib.import_module(module_name)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
custom_defined_connection_class = getattr(module, custom_type_class_name)

if to_class and issubclass(to_class, CustomConnection):
# No need to convert.
return self

custom_defined_connection_class = to_class or custom_defined_connection_class
connection_instance = custom_defined_connection_class(configs=self.configs, secrets=self.secrets)

return connection_instance
Expand Down
50 changes: 44 additions & 6 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def resolve_annotation(anno) -> Union[str, list]:
return args[0] if len(args) == 1 else args


def param_to_definition(param) -> (InputDefinition, bool):
def param_to_definition(param, gen_custom_type_conn=False) -> (InputDefinition, bool):
default_value = param.default
# Get value type and enum from annotation
value_type = resolve_annotation(param.annotation)
enum = None
custom_type_conn = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
Expand All @@ -51,20 +52,57 @@ def param_to_definition(param) -> (InputDefinition, bool):
value_type = str
is_connection = False
if ConnectionType.is_connection_value(value_type):
typ = [value_type.__name__]
if ConnectionType.is_custom_strong_type(value_type):
typ = ["CustomConnection"]
custom_type_conn = [value_type.__name__]
else:
typ = [value_type.__name__]
is_connection = True
elif isinstance(value_type, list):
if not all(ConnectionType.is_connection_value(t) for t in value_type):
typ = [ValueType.OBJECT]
else:
typ = [t.__name__ for t in value_type]
custom_connection_added = False
typ = []
custom_type_conn = []
for t in value_type:
16oeahr marked this conversation as resolved.
Show resolved Hide resolved
# Add 'CustomConnection' to the typ list only once when any custom strong type connection exists, and collect the names of all custom types.
if ConnectionType.is_custom_strong_type(t):
if not custom_connection_added:
custom_connection_added = True
typ.append("CustomConnection")
custom_type_conn.append(t.__name__)
else:
if t.__name__ != "CustomConnection":
typ.append(t.__name__)
elif not custom_connection_added:
custom_connection_added = True
typ.append(t.__name__)
is_connection = True
else:
typ = [ValueType.from_type(value_type)]
return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection

# 1. Do not generate the custom type when generating flow.tools.json for a script tool.
# The extension would show the custom type if it exists, whereas for a script tool with a custom strong type
# connection we still want to show the 'CustomConnection' type.
# 2. Generate the custom connection type when resolving the tool in _tool_resolver, since we rely on it to
# convert the custom connection to a custom strong type connection.
if not gen_custom_type_conn:
custom_type_conn = None

return (
InputDefinition(
type=typ,
default=value_to_str(default_value),
description=None,
enum=enum,
custom_type=custom_type_conn,
),
is_connection,
)


def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
Expand All @@ -83,7 +121,7 @@ def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
input_def, is_connection = param_to_definition(v)
input_def, is_connection = param_to_definition(v, gen_custom_type_conn=gen_custom_type_conn)
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
Expand Down
4 changes: 4 additions & 0 deletions src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ class InputDefinition:
default: str = None
description: str = None
enum: List[str] = None
# Param 'custom_type' is currently used for inputs of custom strong type connection.
# For a custom strong type connection input, the type should be 'CustomConnection',
# while the custom_type should be the custom strong type connection class name.
custom_type: List[str] = None

def serialize(self) -> dict:
Expand Down Expand Up @@ -285,6 +288,7 @@ def _deserialize_type(v):
data.get("default", ""),
data.get("description", ""),
data.get("enum", []),
data.get("custom_type", []),
16oeahr marked this conversation as resolved.
Show resolved Hide resolved
)


Expand Down
43 changes: 34 additions & 9 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import copy
import inspect
import types
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional

from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._sdk.entities import CustomConnection
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
Expand Down Expand Up @@ -53,10 +53,6 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c
connection_value = self._connection_manager.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

if isinstance(connection_value, CustomConnection) and connection_value._is_custom_strong_type():
return connection_value._convert_to_custom_strong_type()

# Check if type matched
if not any(type(connection_value).__name__ == typ for typ in conn_types):
msg = (
Expand All @@ -66,7 +62,23 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c
raise NodeInputValidationError(message=msg)
return connection_value

def _convert_node_literal_input_types(self, node: Node, tool: Tool):
def _convert_to_custom_strong_type_connection_value(
    self, k: str, v: InputAssignment, node: Node, conn_types: List[str], module: types.ModuleType
):
    """Look up the connection named by an input and convert it to a custom strong type connection.

    :param k: Name of the node input being resolved (used in error messages).
    :param v: Input assignment whose ``value`` is the connection name to look up.
    :param node: Node that owns the input; its source type decides how the target class is found.
    :param conn_types: Custom strong type class names declared for the input. Only the first is used.
    :param module: Already-loaded script tool module from which the target class is read.
    :return: Whatever the connection's ``_convert_to_custom_strong_type`` returns.
    :raises NodeInputValidationError: If ``conn_types`` is None.
    :raises ConnectionNotFound: If the connection manager has no connection for ``v.value``.
    """
    if conn_types is None:
        raise NodeInputValidationError(message=f"Input '{k}' for node '{node.name}' has invalid types: None.")

    conn = self._connection_manager.get(v.value)
    if not conn:
        raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

    # For a script (code) tool the class is read from the module passed in by the caller.
    # For any other source no class is passed, so the connection resolves the target class itself.
    is_script_tool = node.source.type == ToolSourceType.Code
    target_class = getattr(module, conn_types[0]) if is_script_tool else None
    return conn._convert_to_custom_strong_type(to_class=target_class)

def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: types.ModuleType = None):
updated_inputs = {
k: v
for k, v in node.inputs.items()
Expand All @@ -81,7 +93,12 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool):
value_type = tool_input.type[0]
updated_inputs[k] = InputAssignment(value=v.value, value_type=InputValueType.LITERAL)
if ConnectionType.is_connection_class_name(value_type):
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
if tool_input.custom_type:
updated_inputs[k].value = self._convert_to_custom_strong_type_connection_value(
k, v, node, tool_input.custom_type, module=module
)
else:
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
Expand Down Expand Up @@ -221,9 +238,17 @@ def _resolve_llm_connection_to_inputs(self, node: Node, tool: Tool) -> Node:
)

def _resolve_script_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
f, tool = self._tool_loader.load_tool_for_script_node(node)
m, f, tool = self._tool_loader.load_tool_for_script_node(node)
# We only want to load the script tool module once.
# Reloading the same module changes the ID of its classes, which can cause issues with isinstance() checks.
# This matters for connection class checks. For instance, a user tool script may contain:
# isinstance(conn, MyCustomConnection)
# The custom script tool and the custom strong type connection are defined in the same module.
# The module is loaded for the first time on the line above, when the tool is loaded.
# The module is needed again to convert the custom connection to its strong type while converting input types.
# To avoid reloading, pass the already-loaded module to _convert_node_literal_input_types as an argument.
if convert_input_types:
node = self._convert_node_literal_input_types(node, tool)
node = self._convert_node_literal_input_types(node, tool, m)
return ResolvedTool(node=node, definition=tool, callable=f, init_args={})

def _resolve_package_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,6 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro
gen_tool_by_source("fake_name", tool_source, tool_type, working_dir),
assert str(ex.value) == error_message

def test(self):
16oeahr marked this conversation as resolved.
Show resolved Hide resolved
tools, specs, templates = collect_package_tools_and_connections()
from promptflow._sdk._utils import refresh_connections_dir

refresh_connections_dir(specs, templates)

def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
# Need to reload pkg_resources to get the latest installed tools
import importlib
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import inspect
from typing import Union

import pytest

from promptflow._utils.tool_utils import function_to_interface
from promptflow._utils.tool_utils import function_to_interface, param_to_definition
from promptflow.connections import AzureOpenAIConnection, CustomConnection
from promptflow.contracts.tool import ValueType

Expand All @@ -24,3 +27,56 @@ def func(input_str: str):
with pytest.raises(Exception) as exec_info:
function_to_interface(func, {"input_str": "test"})
assert "Duplicate inputs found from" in exec_info.value.args[0]

def test_param_to_definition(self):
    """Check the ``type`` and ``custom_type`` that param_to_definition yields for connection annotations."""
    from promptflow._sdk.entities import CustomStrongTypeConnection
    from promptflow.contracts.tool import Secret

    class MyFirstConnection(CustomStrongTypeConnection):
        api_key: Secret
        api_base: str

    class MySecondConnection(CustomStrongTypeConnection):
        api_key: Secret
        api_base: str

    def some_func(
        conn1: MyFirstConnection,
        conn2: Union[CustomConnection, MyFirstConnection],
        conn3: Union[MyFirstConnection, CustomConnection],
        conn4: Union[MyFirstConnection, MySecondConnection],
        conn5: CustomConnection,
        conn6: Union[CustomConnection, int],
        conn7: Union[MyFirstConnection, int],
    ):
        pass

    parameters = inspect.signature(some_func).parameters

    # (parameter name, expected type, expected custom_type)
    expectations = [
        # A single custom strong type connection.
        ("conn1", ["CustomConnection"], ["MyFirstConnection"]),
        # Custom strong type mixed with CustomConnection, in either order.
        ("conn2", ["CustomConnection"], ["MyFirstConnection"]),
        ("conn3", ["CustomConnection"], ["MyFirstConnection"]),
        # Several custom strong types share one 'CustomConnection' entry.
        ("conn4", ["CustomConnection"], ["MyFirstConnection", "MySecondConnection"]),
        # A plain CustomConnection carries no custom type.
        ("conn5", ["CustomConnection"], None),
        # A union containing a non-connection type is treated as a generic object.
        ("conn6", [ValueType.OBJECT], None),
        ("conn7", [ValueType.OBJECT], None),
    ]

    for param_name, expected_type, expected_custom_type in expectations:
        input_def, _ = param_to_definition(parameters.get(param_name), gen_custom_type_conn=True)
        assert input_def.type == expected_type
        if expected_custom_type is None:
            assert input_def.custom_type is None
        else:
            assert input_def.custom_type == expected_custom_type
Loading