Skip to content

Commit

Permalink
support custom strong type connection in script tool
Browse files Browse the repository at this point in the history
  • Loading branch information
yalu4 committed Oct 12, 2023
1 parent 26235f3 commit c58062c
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 39 deletions.
20 changes: 17 additions & 3 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import json
import re
import sys
import types
from dataclasses import asdict
from pathlib import Path
Expand Down Expand Up @@ -111,13 +112,13 @@ def collect_tool_methods_in_module(m):
return tools


def _parse_tool_from_function(f):
def _parse_tool_from_function(f, should_gen_custom_type=False):
if hasattr(f, "__tool") and isinstance(f.__tool, Tool):
return f.__tool
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(f)
inputs, _, _ = function_to_interface(f, should_gen_custom_type=should_gen_custom_type)
except Exception as e:
raise BadFunctionInterface(f"Failed to parse interface for tool {f.__name__}, reason: {e}") from e
class_name = None
Expand Down Expand Up @@ -149,7 +150,19 @@ def generate_python_tools_in_module_as_dict(module):
def load_python_module_from_file(src_file: Path):
# Here we hard code the module name as __pf_main__ since it is invoked as a main script in pf.
src_file = Path(src_file).resolve() # Make sure the path is absolute to align with python import behavior.
spec = importlib.util.spec_from_file_location("__pf_main__", location=src_file)
module_name = "__pf_main__"

# Reloading the same module changes the ID of the class, which can cause issues with isinstance() checks.
# This is important when working with connection class checks. For instance, in user tool script it writes:
# isinstance(conn, MyCustomConnection)
# Custom defined script tool and custom defined strong type connection are in the same module.
# The first time to load the module is when converting the custom connection to custom strong type.
# The second time to load the module is when _resolve_script_node. The isinstance() check will fail.
# To avoid this, return the already loaded module if it exists in sys.modules instead of reloading it.
if module_name in sys.modules and sys.modules[module_name].__file__ == str(src_file):
return sys.modules[module_name]

spec = importlib.util.spec_from_file_location(module_name, location=src_file)
if spec is None or spec.loader is None:
raise PythonLoaderNotFound(f"Failed to load python file '{src_file}', please make sure it is a valid .py file.")
m = importlib.util.module_from_spec(spec)
Expand All @@ -158,6 +171,7 @@ def load_python_module_from_file(src_file: Path):
except Exception as e:
# TODO: add stacktrace to additional info
raise PythonLoadError(f"Failed to load python module from file '{src_file}', reason: {e}.") from e
sys.modules[module_name] = m
return m


Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def load_tool_for_script_node(self, node: Node) -> Tuple[Callable, Tool]:
if m is None:
raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
f = collect_tool_function_in_module(m)
return f, _parse_tool_from_function(f)
return f, _parse_tool_from_function(f, should_gen_custom_type=True)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
Expand Down
24 changes: 16 additions & 8 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def _convert_to_custom(self):
return custom_connection

@classmethod
def _get_custom_keys(cls, data):
def _get_custom_keys(cls, data: Dict):
# The data could be either from yaml or from DB.
# If from yaml, 'custom_type' and 'module' are outside the configs of data.
# If from DB, 'custom_type' and 'module' are within the configs of data.
Expand Down Expand Up @@ -933,13 +933,21 @@ def _is_custom_strong_type(self):
and self.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY]
)

def _convert_to_custom_strong_type(self):
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
import importlib

module = importlib.import_module(module_name)
custom_defined_connection_class = getattr(module, custom_type_class_name)
def _convert_to_custom_strong_type(self, custom_cls) -> CustomStrongTypeConnection:
# There are two scenarios to convert a custom connection to custom strong type connection:
# 1. The connection is created from a custom strong type connection template file.
# Custom type and module name are present in the configs.
# 2. The connection is created through SDK PFClient or a custom connection template file.
# Custom type and module name are not present in the configs. Module and class must be passed for conversion.
if not custom_cls:
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
import importlib

module = importlib.import_module(module_name)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
custom_defined_connection_class = getattr(module, custom_type_class_name)

custom_defined_connection_class = custom_cls or custom_defined_connection_class
connection_instance = custom_defined_connection_class(configs=self.configs, secrets=self.secrets)

return connection_instance
Expand Down
35 changes: 29 additions & 6 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def resolve_annotation(anno) -> Union[str, list]:
return args[0] if len(args) == 1 else args


def param_to_definition(param) -> (InputDefinition, bool):
def param_to_definition(param, should_gen_custom_type=False) -> (InputDefinition, bool):
default_value = param.default
# Get value type and enum from annotation
value_type = resolve_annotation(param.annotation)
enum = None
custom_type = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
Expand All @@ -51,20 +52,42 @@ def param_to_definition(param) -> (InputDefinition, bool):
value_type = str
is_connection = False
if ConnectionType.is_connection_value(value_type):
typ = [value_type.__name__]
if ConnectionType.is_custom_strong_type(value_type):
typ = ["CustomConnection"]
custom_type = [value_type.__name__]
else:
typ = [value_type.__name__]
is_connection = True
elif isinstance(value_type, list):
if not all(ConnectionType.is_connection_value(t) for t in value_type):
typ = [ValueType.OBJECT]
else:
typ = [t.__name__ for t in value_type]
custom_connection_added = False
typ = []
custom_type = []
for t in value_type:
if ConnectionType.is_custom_strong_type(t):
if not custom_connection_added:
custom_connection_added = True
typ.append("CustomConnection")
custom_type.append(t.__name__)
else:
typ.append(t.__name__)
is_connection = True
else:
typ = [ValueType.from_type(value_type)]
return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection
# Do not generate custom type when generating flow.tools.json for script tool.
if not should_gen_custom_type:
custom_type = None
return (
InputDefinition(
type=typ, default=value_to_str(default_value), description=None, enum=enum, custom_type=custom_type
),
is_connection,
)


def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
def function_to_interface(f: Callable, initialize_inputs=None, should_gen_custom_type=False) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
Expand All @@ -83,7 +106,7 @@ def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
input_def, is_connection = param_to_definition(v)
input_def, is_connection = param_to_definition(v, should_gen_custom_type=should_gen_custom_type)
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
Expand Down
1 change: 1 addition & 0 deletions src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def _deserialize_type(v):
data.get("default", ""),
data.get("description", ""),
data.get("enum", []),
data.get("custom_type", []),
)


Expand Down
27 changes: 21 additions & 6 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import Callable, List, Optional

from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tool_meta_generator import load_python_module_from_file
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._sdk.entities import CustomConnection
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
Expand Down Expand Up @@ -53,10 +53,6 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c
connection_value = self._connection_manager.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

if isinstance(connection_value, CustomConnection) and connection_value._is_custom_strong_type():
return connection_value._convert_to_custom_strong_type()

# Check if type matched
if not any(type(connection_value).__name__ == typ for typ in conn_types):
msg = (
Expand All @@ -66,6 +62,20 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c
raise NodeInputValidationError(message=msg)
return connection_value

def _convert_to_custom_strong_type_connection_value(
self, k: str, v: InputAssignment, node: Node, conn_types: List[ValueType], module_path=Optional[str]
):
connection_value = self._connection_manager.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

custom_defined_connection_class = None
if node.source.type == ToolSourceType.Code:
custom_m = load_python_module_from_file(module_path)
custom_type_class_name = conn_types[0]
custom_defined_connection_class = getattr(custom_m, custom_type_class_name)
return connection_value._convert_to_custom_strong_type(custom_defined_connection_class)

def _convert_node_literal_input_types(self, node: Node, tool: Tool):
updated_inputs = {
k: v
Expand All @@ -81,7 +91,12 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool):
value_type = tool_input.type[0]
updated_inputs[k] = InputAssignment(value=v.value, value_type=InputValueType.LITERAL)
if ConnectionType.is_connection_class_name(value_type):
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
if tool_input.custom_type:
updated_inputs[k].value = self._convert_to_custom_strong_type_connection_value(
k, v, node, tool_input.custom_type, module_path=node.source.path
)
else:
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,6 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro
gen_tool_by_source("fake_name", tool_source, tool_type, working_dir),
assert str(ex.value) == error_message

def test(self):
tools, specs, templates = collect_package_tools_and_connections()
from promptflow._sdk._utils import refresh_connections_dir

refresh_connections_dir(specs, templates)

def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
# Need to reload pkg_resources to get the latest installed tools
import importlib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def test_basic_flow_run_bulk_without_env(self, pf, runtime) -> None:
assert isinstance(run, Run)

@pytest.mark.skip("Custom tool pkg and promptprompt pkg with CustomStrongTypeConnection not installed on runtime.")
def test_basic_flow_run_with_custom_strong_type_connection(self, pf, runtime) -> None:
def test_basic_flow_with_package_tool_with_custom_strong_type_connection(self, pf, runtime) -> None:
name = str(uuid.uuid4())
run_pf_command(
"run",
"create",
"--flow",
f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection",
"--data",
f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl",
"--name",
name,
pf=pf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,58 @@ def test_flow_generate_tools_meta(self, pf) -> None:
"package": {},
}
assert tools_error == {}

def test_flow_generate_tools_meta_with_pkg_tool_with_custom_strong_type_connection(self, pf) -> None:
source = f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection"

tools_meta, tools_error = pf.flows._generate_tools_meta(source)

assert tools_error == {}
assert tools_meta["code"] == {}
assert tools_meta["package"] == {
"my_tool_package.tools.my_tool_1.my_tool": {
"function": "my_tool",
"inputs": {
"connection": {
"type": ["CustomConnection"],
"custom_type": ["MyFirstConnection", "MySecondConnection"],
},
"input_text": {"type": ["string"]},
},
"module": "my_tool_package.tools.my_tool_1",
"name": "My First Tool",
"description": "This is my first tool",
"type": "python",
"package": "test-custom-tools",
"package_version": "0.0.2",
},
"my_tool_package.tools.my_tool_2.MyTool.my_tool": {
"class_name": "MyTool",
"function": "my_tool",
"inputs": {
"connection": {"type": ["CustomConnection"], "custom_type": ["MySecondConnection"]},
"input_text": {"type": ["string"]},
},
"module": "my_tool_package.tools.my_tool_2",
"name": "My Second Tool",
"description": "This is my second tool",
"type": "python",
"package": "test-custom-tools",
"package_version": "0.0.2",
},
}

def test_flow_generate_tools_meta_with_script_tool_with_custom_strong_type_connection(self, pf) -> None:
source = f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection"

tools_meta, tools_error = pf.flows._generate_tools_meta(source)
assert tools_error == {}
assert tools_meta["package"] == {}
assert tools_meta["code"] == {
"my_script_tool.py": {
"function": "my_tool",
"inputs": {"connection": {"type": ["CustomConnection"]}, "input_param": {"type": ["string"]}},
"source": "my_script_tool.py",
"type": "python",
}
}
8 changes: 5 additions & 3 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def test_custom_connection_overwrite(self, local_client, local_custom_connection
)
assert "Connection with name new_connection not found" in str(e.value)

def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg, local_client, pf):
def test_basic_flow_with_package_tool_with_custom_strong_type_connection(
self, install_custom_tool_pkg, local_client, pf
):
# Need to reload pkg_resources to get the latest installed tools
import importlib

Expand All @@ -256,8 +258,8 @@ def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg,
importlib.reload(pkg_resources)

result = pf.run(
flow=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
data=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
flow=f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection",
data=f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl",
connections={"My_First_Tool_00f8": {"connection": "custom_strong_type_connection"}},
)
run = local_client.runs.get(name=result.name)
Expand Down
24 changes: 21 additions & 3 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_pf_test_flow(self):
result = _client.test(flow=f"{FLOWS_DIR}/web_classification")
assert all([key in FLOW_RESULT_KEYS for key in result])

def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_tool_pkg):
def test_pf_test_flow_with_package_tool_with_custom_strong_type_connection(self, install_custom_tool_pkg):
# Need to reload pkg_resources to get the latest installed tools
import importlib

Expand All @@ -42,16 +42,34 @@ def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_to
importlib.reload(pkg_resources)

inputs = {"text": "Hello World!"}
flow_path = Path(f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow").absolute()
flow_path = Path(f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection").absolute()

# Test that connection would be custom strong type in flow
result = _client.test(flow=flow_path, inputs=inputs)
assert result == {"out": "connection_value is MyFirstConnection: True"}

# Test that connection
# Test node run
result = _client.test(flow=flow_path, inputs={"input_text": "Hello World!"}, node="My_Second_Tool_usi3")
assert result == "Hello World!This is my first custom connection."

def test_pf_test_flow_with_script_tool_with_custom_strong_type_connection(self):
# Prepare custom connection
from promptflow.connections import CustomConnection

conn = CustomConnection(name="custom_connection_2", secrets={"api_key": "test"}, configs={"api_url": "test"})
_client.connections.create_or_update(conn)

inputs = {"text": "Hello World!"}
flow_path = Path(f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection").absolute()

# Test that connection would be custom strong type in flow
result = _client.test(flow=flow_path, inputs=inputs)
assert result == {"out": "connection_value is MyCustomConnection: True"}

# Test node run
result = _client.test(flow=flow_path, inputs={"input_param": "Hello World!"}, node="my_script_tool")
assert result == "connection_value is MyCustomConnection: True"

def test_pf_test_with_streaming_output(self):
flow_path = Path(f"{FLOWS_DIR}/chat_flow_with_stream_output")
result = _client.test(flow=flow_path)
Expand Down
Loading

0 comments on commit c58062c

Please sign in to comment.