diff --git a/src/promptflow/promptflow/_core/tool_meta_generator.py b/src/promptflow/promptflow/_core/tool_meta_generator.py index aebe5b57908..f1a664c6e2f 100644 --- a/src/promptflow/promptflow/_core/tool_meta_generator.py +++ b/src/promptflow/promptflow/_core/tool_meta_generator.py @@ -120,13 +120,13 @@ def collect_tool_methods_in_module(m): return tools -def _parse_tool_from_function(f): +def _parse_tool_from_function(f, gen_custom_type_conn=False): if hasattr(f, "__tool") and isinstance(f.__tool, Tool): return f.__tool if hasattr(f, "__original_function"): f = f.__original_function try: - inputs, _, _ = function_to_interface(f) + inputs, _, _ = function_to_interface(f, gen_custom_type_conn=gen_custom_type_conn) except Exception as e: error_type_and_message = f"({e.__class__.__name__}) {e}" raise BadFunctionInterface( diff --git a/src/promptflow/promptflow/_core/tools_manager.py b/src/promptflow/promptflow/_core/tools_manager.py index c2c91f32b4e..35404791515 100644 --- a/src/promptflow/promptflow/_core/tools_manager.py +++ b/src/promptflow/promptflow/_core/tools_manager.py @@ -7,6 +7,7 @@ import inspect import logging import traceback +import types from functools import partial from pathlib import Path from typing import Callable, List, Mapping, Optional, Tuple, Union @@ -347,7 +348,7 @@ def load_tool_for_node(self, node: Node) -> Tool: if node.source.type == ToolSourceType.Package: return self.load_tool_for_package_node(node) elif node.source.type == ToolSourceType.Code: - _, tool = self.load_tool_for_script_node(node) + _, _, tool = self.load_tool_for_script_node(node) return tool raise NotImplementedError(f"Tool source type {node.source.type} for python tool is not supported yet.") elif node.type == ToolType.CUSTOM_LLM: @@ -366,7 +367,7 @@ def load_tool_for_package_node(self, node: Node) -> Tool: target=ErrorTarget.EXECUTOR, ) - def load_tool_for_script_node(self, node: Node) -> Tuple[Callable, Tool]: + def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]: if node.source.path is None: raise UserErrorException(f"Node {node.name} does not have source path defined.") path = node.source.path @@ -374,7 +375,7 @@ def load_tool_for_script_node(self, node: Node) -> Tuple[Callable, Tool]: if m is None: raise CustomToolSourceLoadError(f"Cannot load module from {path}.") f = collect_tool_function_in_module(m) - return f, _parse_tool_from_function(f) + return m, f, _parse_tool_from_function(f, gen_custom_type_conn=True) def load_tool_for_llm_node(self, node: Node) -> Tool: api_name = f"{node.provider}.{node.api}" diff --git a/src/promptflow/promptflow/_sdk/entities/_connection.py b/src/promptflow/promptflow/_sdk/entities/_connection.py index 987987367ea..86346e90ad5 100644 --- a/src/promptflow/promptflow/_sdk/entities/_connection.py +++ b/src/promptflow/promptflow/_sdk/entities/_connection.py @@ -715,7 +715,7 @@ def _convert_to_custom(self): return custom_connection @classmethod - def _get_custom_keys(cls, data): + def _get_custom_keys(cls, data: Dict): # The data could be either from yaml or from DB. # If from yaml, 'custom_type' and 'module' are outside the configs of data. # If from DB, 'custom_type' and 'module' are within the configs of data. @@ -933,13 +933,28 @@ def _is_custom_strong_type(self): and self.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY] ) - def _convert_to_custom_strong_type(self): - module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY) - custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY) - import importlib + def _convert_to_custom_strong_type(self, to_class) -> CustomStrongTypeConnection: + # There are two scenarios to convert a custom connection to custom strong type connection: + # 1. The connection is created from a custom strong type connection template file. + # Custom type and module name are present in the configs. + # 2. The connection is created through SDK PFClient or a custom connection template file. + # Custom type and module name are not present in the configs. Module and class must be passed for conversion. + if to_class and not isinstance(to_class, type): + raise TypeError(f"The converted type {to_class} must be a class type.") - module = importlib.import_module(module_name) - custom_defined_connection_class = getattr(module, custom_type_class_name) + if not to_class: + module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY) + import importlib + + module = importlib.import_module(module_name) + custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY) + custom_defined_connection_class = getattr(module, custom_type_class_name) + + if to_class and issubclass(to_class, CustomConnection): + # No need to convert. + return self + + custom_defined_connection_class = to_class or custom_defined_connection_class connection_instance = custom_defined_connection_class(configs=self.configs, secrets=self.secrets) return connection_instance diff --git a/src/promptflow/promptflow/_utils/tool_utils.py b/src/promptflow/promptflow/_utils/tool_utils.py index 60e1b765663..ec1b136a8f2 100644 --- a/src/promptflow/promptflow/_utils/tool_utils.py +++ b/src/promptflow/promptflow/_utils/tool_utils.py @@ -37,11 +37,12 @@ def resolve_annotation(anno) -> Union[str, list]: return args[0] if len(args) == 1 else args -def param_to_definition(param) -> (InputDefinition, bool): +def param_to_definition(param, gen_custom_type_conn=False) -> (InputDefinition, bool): default_value = param.default # Get value type and enum from annotation value_type = resolve_annotation(param.annotation) enum = None + custom_type_conn = None # Get value type and enum from default if no annotation if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty: value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value) @@ -51,20 +52,57 @@ def param_to_definition(param) -> (InputDefinition, bool): value_type = str is_connection = False if ConnectionType.is_connection_value(value_type): - typ = [value_type.__name__] + if ConnectionType.is_custom_strong_type(value_type): + typ = ["CustomConnection"] + custom_type_conn = [value_type.__name__] + else: + typ = [value_type.__name__] is_connection = True elif isinstance(value_type, list): if not all(ConnectionType.is_connection_value(t) for t in value_type): typ = [ValueType.OBJECT] else: - typ = [t.__name__ for t in value_type] + custom_connection_added = False + typ = [] + custom_type_conn = [] + for t in value_type: + # Add 'CustomConnection' to typ list when custom strong type connection exists. Collect all custom types + if ConnectionType.is_custom_strong_type(t): + if not custom_connection_added: + custom_connection_added = True + typ.append("CustomConnection") + custom_type_conn.append(t.__name__) + else: + if t.__name__ != "CustomConnection": + typ.append(t.__name__) + elif not custom_connection_added: + custom_connection_added = True + typ.append(t.__name__) is_connection = True else: typ = [ValueType.from_type(value_type)] - return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection + + # 1. Do not generate custom type when generating flow.tools.json for script tool. + # Extension would show custom type if it exists. While for script tool with custom strong type connection, + # we still want to show 'CustomConnection' type. + # 2. Generate custom connection type when resolving tool in _tool_resolver, since we rely on it to convert the + # custom connection to custom strong type connection. + if not gen_custom_type_conn: + custom_type_conn = None + + return ( + InputDefinition( + type=typ, + default=value_to_str(default_value), + description=None, + enum=enum, + custom_type=custom_type_conn, + ), + is_connection, + ) -def function_to_interface(f: Callable, initialize_inputs=None) -> tuple: +def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False) -> tuple: sign = inspect.signature(f) all_inputs = {} input_defs = {} @@ -83,7 +121,7 @@ def function_to_interface(f: Callable, initialize_inputs=None) -> tuple: ) # Resolve inputs to definitions. for k, v in all_inputs.items(): - input_def, is_connection = param_to_definition(v) + input_def, is_connection = param_to_definition(v, gen_custom_type_conn=gen_custom_type_conn) input_defs[k] = input_def if is_connection: connection_types.append(input_def.type) diff --git a/src/promptflow/promptflow/contracts/tool.py b/src/promptflow/promptflow/contracts/tool.py index 199a5a36b23..6fe7cb9a0cd 100644 --- a/src/promptflow/promptflow/contracts/tool.py +++ b/src/promptflow/promptflow/contracts/tool.py @@ -241,6 +241,9 @@ class InputDefinition: default: str = None description: str = None enum: List[str] = None + # Param 'custom_type' is currently used for inputs of custom strong type connection. + # For a custom strong type connection input, the type should be 'CustomConnection', + # while the custom_type should be the custom strong type connection class name. custom_type: List[str] = None def serialize(self) -> dict: @@ -285,6 +288,7 @@ def _deserialize_type(v): data.get("default", ""), data.get("description", ""), data.get("enum", []), + data.get("custom_type", []), ) diff --git a/src/promptflow/promptflow/executor/_tool_resolver.py b/src/promptflow/promptflow/executor/_tool_resolver.py index 62ce9db3f84..501534edd71 100644 --- a/src/promptflow/promptflow/executor/_tool_resolver.py +++ b/src/promptflow/promptflow/executor/_tool_resolver.py @@ -4,6 +4,7 @@ import copy import inspect +import types from dataclasses import dataclass from functools import partial from pathlib import Path @@ -11,7 +12,6 @@ from promptflow._core.connection_manager import ConnectionManager from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping -from promptflow._sdk.entities import CustomConnection from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType @@ -53,10 +53,6 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c connection_value = self._connection_manager.get(v.value) if not connection_value: raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.") - - if isinstance(connection_value, CustomConnection) and connection_value._is_custom_strong_type(): - return connection_value._convert_to_custom_strong_type() - # Check if type matched if not any(type(connection_value).__name__ == typ for typ in conn_types): msg = ( @@ -66,7 +62,23 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c raise NodeInputValidationError(message=msg) return connection_value - def _convert_node_literal_input_types(self, node: Node, tool: Tool): + def _convert_to_custom_strong_type_connection_value( + self, k: str, v: InputAssignment, node: Node, conn_types: List[str], module: types.ModuleType + ): + if conn_types is None: + msg = f"Input '{k}' for node '{node.name}' has invalid types: None." + raise NodeInputValidationError(message=msg) + connection_value = self._connection_manager.get(v.value) + if not connection_value: + raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.") + + custom_defined_connection_class = None + if node.source.type == ToolSourceType.Code: + custom_type_class_name = conn_types[0] + custom_defined_connection_class = getattr(module, custom_type_class_name) + return connection_value._convert_to_custom_strong_type(to_class=custom_defined_connection_class) + + def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: types.ModuleType = None): updated_inputs = { k: v for k, v in node.inputs.items() @@ -81,7 +93,12 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool): value_type = tool_input.type[0] updated_inputs[k] = InputAssignment(value=v.value, value_type=InputValueType.LITERAL) if ConnectionType.is_connection_class_name(value_type): - updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type) + if tool_input.custom_type: + updated_inputs[k].value = self._convert_to_custom_strong_type_connection_value( + k, v, node, tool_input.custom_type, module=module + ) + else: + updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type) elif isinstance(value_type, ValueType): try: updated_inputs[k].value = value_type.parse(v.value) @@ -221,9 +238,17 @@ def _resolve_llm_connection_to_inputs(self, node: Node, tool: Tool) -> Node: ) def _resolve_script_node(self, node: Node, convert_input_types=False) -> ResolvedTool: - f, tool = self._tool_loader.load_tool_for_script_node(node) + m, f, tool = self._tool_loader.load_tool_for_script_node(node) + # We only want to load script tool module once. + # Reloading the same module changes the ID of the class, which can cause issues with isinstance() checks. + # This is important when working with connection class checks. For instance, in user tool script it writes: + # isinstance(conn, MyCustomConnection) + # Custom defined script tool and custom defined strong type connection are in the same module. + # The first time to load the module is in above line when loading a tool. + # We need the module again when converting the custom connection to strong type when converting input types. + # To avoid reloading, pass the loaded module to _convert_node_literal_input_types as an arg. if convert_input_types: - node = self._convert_node_literal_input_types(node, tool) + node = self._convert_node_literal_input_types(node, tool, m) return ResolvedTool(node=node, definition=tool, callable=f, init_args={}) def _resolve_package_node(self, node: Node, convert_input_types=False) -> ResolvedTool: diff --git a/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py b/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py index b06ec468b25..36a5ebe6559 100644 --- a/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py +++ b/src/promptflow/tests/executor/unittests/_core/test_tools_manager.py @@ -137,12 +137,6 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro gen_tool_by_source("fake_name", tool_source, tool_type, working_dir), assert str(ex.value) == error_message - def test(self): - tools, specs, templates = collect_package_tools_and_connections() - from promptflow._sdk._utils import refresh_connections_dir - - refresh_connections_dir(specs, templates) - def test_collect_package_tools_and_connections(self, install_custom_tool_pkg): # Need to reload pkg_resources to get the latest installed tools import importlib diff --git a/src/promptflow/tests/executor/unittests/_utils/test_tool_utils.py b/src/promptflow/tests/executor/unittests/_utils/test_tool_utils.py index fe547980784..845d904c838 100644 --- a/src/promptflow/tests/executor/unittests/_utils/test_tool_utils.py +++ b/src/promptflow/tests/executor/unittests/_utils/test_tool_utils.py @@ -1,6 +1,9 @@ +import inspect +from typing import Union + import pytest -from promptflow._utils.tool_utils import function_to_interface +from promptflow._utils.tool_utils import function_to_interface, param_to_definition from promptflow.connections import AzureOpenAIConnection, CustomConnection from promptflow.contracts.tool import ValueType @@ -24,3 +27,56 @@ def func(input_str: str): with pytest.raises(Exception) as exec_info: function_to_interface(func, {"input_str": "test"}) assert "Duplicate inputs found from" in exec_info.value.args[0] + + def test_param_to_definition(self): + from promptflow._sdk.entities import CustomStrongTypeConnection + from promptflow.contracts.tool import Secret + + class MyFirstConnection(CustomStrongTypeConnection): + api_key: Secret + api_base: str + + class MySecondConnection(CustomStrongTypeConnection): + api_key: Secret + api_base: str + + def some_func( + conn1: MyFirstConnection, + conn2: Union[CustomConnection, MyFirstConnection], + conn3: Union[MyFirstConnection, CustomConnection], + conn4: Union[MyFirstConnection, MySecondConnection], + conn5: CustomConnection, + conn6: Union[CustomConnection, int], + conn7: Union[MyFirstConnection, int], + ): + pass + + sig = inspect.signature(some_func) + + input_def, _ = param_to_definition(sig.parameters.get("conn1"), gen_custom_type_conn=True) + assert input_def.type == ["CustomConnection"] + assert input_def.custom_type == ["MyFirstConnection"] + + input_def, _ = param_to_definition(sig.parameters.get("conn2"), gen_custom_type_conn=True) + assert input_def.type == ["CustomConnection"] + assert input_def.custom_type == ["MyFirstConnection"] + + input_def, _ = param_to_definition(sig.parameters.get("conn3"), gen_custom_type_conn=True) + assert input_def.type == ["CustomConnection"] + assert input_def.custom_type == ["MyFirstConnection"] + + input_def, _ = param_to_definition(sig.parameters.get("conn4"), gen_custom_type_conn=True) + assert input_def.type == ["CustomConnection"] + assert input_def.custom_type == ["MyFirstConnection", "MySecondConnection"] + + input_def, _ = param_to_definition(sig.parameters.get("conn5"), gen_custom_type_conn=True) + assert input_def.type == ["CustomConnection"] + assert input_def.custom_type is None + + input_def, _ = param_to_definition(sig.parameters.get("conn6"), gen_custom_type_conn=True) + assert input_def.type == [ValueType.OBJECT] + assert input_def.custom_type is None + + input_def, _ = param_to_definition(sig.parameters.get("conn7"), gen_custom_type_conn=True) + assert input_def.type == [ValueType.OBJECT] + assert input_def.custom_type is None diff --git a/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py b/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py index b300410ed35..d983cd05696 100644 --- a/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py +++ b/src/promptflow/tests/executor/unittests/executor/test_tool_resolver.py @@ -1,11 +1,13 @@ +import sys from pathlib import Path import pytest from promptflow._core.tools_manager import ToolLoader +from promptflow._sdk.entities import CustomConnection, CustomStrongTypeConnection from promptflow.connections import AzureOpenAIConnection from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType -from promptflow.contracts.tool import InputDefinition, Tool, ToolType, ValueType +from promptflow.contracts.tool import InputDefinition, Secret, Tool, ToolType, ValueType from promptflow.contracts.types import PromptTemplate from promptflow.exceptions import UserErrorException from promptflow.executor._errors import ( @@ -23,6 +25,11 @@ WRONG_REQUESTS_PATH = TEST_ROOT / "test_configs/executor_wrong_requests" +class MyFirstCSTConnection(CustomStrongTypeConnection): + api_key: Secret + api_base: str + + @pytest.mark.unittest class TestToolResolver: @pytest.fixture @@ -306,7 +313,7 @@ def mock_python_func(conn: AzureOpenAIConnection, prompt: PromptTemplate, **kwar tool_loader = ToolLoader(working_dir=None) tool = Tool(name="mock", type=ToolType.PYTHON, inputs={"conn": InputDefinition(type=["AzureOpenAIConnection"])}) - mocker.patch.object(tool_loader, "load_tool_for_script_node", return_value=(mock_python_func, tool)) + mocker.patch.object(tool_loader, "load_tool_for_script_node", return_value=(None, mock_python_func, tool)) connections = {"conn_name": {"type": "AzureOpenAIConnection", "value": {"api_key": "mock", "api_base": "mock"}}} tool_resolver = ToolResolver(working_dir=None, connections=connections) @@ -387,3 +394,25 @@ def mock_package_func(prompt: PromptTemplate, **kwargs): resolved_tool = tool_resolver._integrate_prompt_in_package_node(node, resolved_tool) kwargs = {k: v.value for k, v in resolved_tool.node.inputs.items()} assert resolved_tool.callable(**kwargs) == "Hello World!" + + @pytest.mark.parametrize( + "conn_types, expected_type", + [ + (["MyFirstCSTConnection"], MyFirstCSTConnection), + (["CustomConnection", "MyFirstCSTConnection"], CustomConnection), + (["CustomConnection", "MyFirstCSTConnection", "MySecondCSTConnection"], CustomConnection), + (["MyFirstCSTConnection", "MySecondCSTConnection"], MyFirstCSTConnection), + ], + ) + def test_convert_to_custom_strong_type_connection_value(self, conn_types: list[str], expected_type, mocker): + connections = {"conn_name": {"type": "CustomConnection", "value": {"api_key": "mock", "api_base": "mock"}}} + tool_resolver = ToolResolver(working_dir=None, connections=connections) + + node = mocker.Mock(name="node", tool=None, inputs={}) + node.type = ToolType.PYTHON + node.source = mocker.Mock(type=ToolSourceType.Code) + m = sys.modules[__name__] + v = InputAssignment(value="conn_name", value_type=InputValueType.LITERAL) + actual = tool_resolver._convert_to_custom_strong_type_connection_value("conn_name", v, node, conn_types, m) + assert isinstance(actual, expected_type) + assert actual.api_base == "mock" diff --git a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py index 847e12e4cdc..5d46b472b80 100644 --- a/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py +++ b/src/promptflow/tests/sdk_cli_azure_test/e2etests/test_cli_with_azure.py @@ -57,15 +57,15 @@ def test_basic_flow_run_bulk_without_env(self, pf, runtime) -> None: assert isinstance(run, Run) @pytest.mark.skip("Custom tool pkg and promptprompt pkg with CustomStrongTypeConnection not installed on runtime.") - def test_basic_flow_run_with_custom_strong_type_connection(self, pf, runtime) -> None: + def test_basic_flow_with_package_tool_with_custom_strong_type_connection(self, pf, runtime) -> None: name = str(uuid.uuid4()) run_pf_command( "run", "create", "--flow", - f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow", + f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection", "--data", - f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl", + f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl", "--name", name, pf=pf, diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py index b9a70e2bb8a..c184ba7994e 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_custom_strong_type_connection.py @@ -129,8 +129,9 @@ def test_connection_get_and_update_with_key(self): assert conn.configs["api_base"] == "test" result = _client.connections.create_or_update(conn) - converted_conn = result._convert_to_custom_strong_type() + converted_conn = result._convert_to_custom_strong_type(MyCustomConnection) + assert isinstance(converted_conn, MyCustomConnection) assert converted_conn.api_base == "test" converted_conn.api_base = "test2" assert converted_conn.api_base == "test2" diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_local_operations.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_local_operations.py index b344edd02bd..6b9b62b5938 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_local_operations.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_local_operations.py @@ -405,3 +405,58 @@ def test_flow_generate_tools_meta(self, pf) -> None: "package": {}, } assert tools_error == {} + + def test_flow_generate_tools_meta_with_pkg_tool_with_custom_strong_type_connection(self, pf) -> None: + source = f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection" + + tools_meta, tools_error = pf.flows._generate_tools_meta(source) + + assert tools_error == {} + assert tools_meta["code"] == {} + assert tools_meta["package"] == { + "my_tool_package.tools.my_tool_1.my_tool": { + "function": "my_tool", + "inputs": { + "connection": { + "type": ["CustomConnection"], + "custom_type": ["MyFirstConnection", "MySecondConnection"], + }, + "input_text": {"type": ["string"]}, + }, + "module": "my_tool_package.tools.my_tool_1", + "name": "My First Tool", + "description": "This is my first tool", + "type": "python", + "package": "test-custom-tools", + "package_version": "0.0.2", + }, + "my_tool_package.tools.my_tool_2.MyTool.my_tool": { + "class_name": "MyTool", + "function": "my_tool", + "inputs": { + "connection": {"type": ["CustomConnection"], "custom_type": ["MySecondConnection"]}, + "input_text": {"type": ["string"]}, + }, + "module": "my_tool_package.tools.my_tool_2", + "name": "My Second Tool", + "description": "This is my second tool", + "type": "python", + "package": "test-custom-tools", + "package_version": "0.0.2", + }, + } + + def test_flow_generate_tools_meta_with_script_tool_with_custom_strong_type_connection(self, pf) -> None: + source = f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection" + + tools_meta, tools_error = pf.flows._generate_tools_meta(source) + assert tools_error == {} + assert tools_meta["package"] == {} + assert tools_meta["code"] == { + "my_script_tool.py": { + "function": "my_tool", + "inputs": {"connection": {"type": ["CustomConnection"]}, "input_param": {"type": ["string"]}}, + "source": "my_script_tool.py", + "type": "python", + } + } diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py index 5205b1504e7..90b20ad2d69 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py @@ -247,7 +247,9 @@ def test_custom_connection_overwrite(self, local_client, local_custom_connection ) assert "Connection with name new_connection not found" in str(e.value) - def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg, local_client, pf): + def test_basic_flow_with_package_tool_with_custom_strong_type_connection( + self, install_custom_tool_pkg, local_client, pf + ): # Need to reload pkg_resources to get the latest installed tools import importlib @@ -256,13 +258,29 @@ def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg, importlib.reload(pkg_resources) result = pf.run( - flow=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow", - data=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl", + flow=f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection", + data=f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl", connections={"My_First_Tool_00f8": {"connection": "custom_strong_type_connection"}}, ) run = local_client.runs.get(name=result.name) assert run.status == "Completed" + def test_basic_flow_with_script_tool_with_custom_strong_type_connection( + self, install_custom_tool_pkg, local_client, pf + ): + # Prepare custom connection + from promptflow.connections import CustomConnection + + conn = CustomConnection(name="custom_connection_2", secrets={"api_key": "test"}, configs={"api_url": "test"}) + local_client.connections.create_or_update(conn) + + result = pf.run( + flow=f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection", + data=f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection/data.jsonl", + ) + run = local_client.runs.get(name=result.name) + assert run.status == "Completed" + def test_run_with_connection_overwrite_non_exist(self, local_client, local_aoai_connection, pf): # overwrite non_exist connection with pytest.raises(Exception) as e: diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py index 75ad028ea71..eaadb918a3b 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py @@ -33,7 +33,7 @@ def test_pf_test_flow(self): result = _client.test(flow=f"{FLOWS_DIR}/web_classification") assert all([key in FLOW_RESULT_KEYS for key in result]) - def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_tool_pkg): + def test_pf_test_flow_with_package_tool_with_custom_strong_type_connection(self, install_custom_tool_pkg): # Need to reload pkg_resources to get the latest installed tools import importlib @@ -42,16 +42,34 @@ def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_to importlib.reload(pkg_resources) inputs = {"text": "Hello World!"} - flow_path = Path(f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow").absolute() + flow_path = Path(f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection").absolute() # Test that connection would be custom strong type in flow result = _client.test(flow=flow_path, inputs=inputs) assert result == {"out": "connection_value is MyFirstConnection: True"} - # Test that connection + # Test node run result = _client.test(flow=flow_path, inputs={"input_text": "Hello World!"}, node="My_Second_Tool_usi3") assert result == "Hello World!This is my first custom connection." + def test_pf_test_flow_with_script_tool_with_custom_strong_type_connection(self): + # Prepare custom connection + from promptflow.connections import CustomConnection + + conn = CustomConnection(name="custom_connection_2", secrets={"api_key": "test"}, configs={"api_url": "test"}) + _client.connections.create_or_update(conn) + + inputs = {"text": "Hello World!"} + flow_path = Path(f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection").absolute() + + # Test that connection would be custom strong type in flow + result = _client.test(flow=flow_path, inputs=inputs) + assert result == {"out": "connection_value is MyCustomConnection: True"} + + # Test node run + result = _client.test(flow=flow_path, inputs={"input_param": "Hello World!"}, node="my_script_tool") + assert result == "connection_value is MyCustomConnection: True" + def test_pf_test_with_streaming_output(self): flow_path = Path(f"{FLOWS_DIR}/chat_flow_with_stream_output") result = _client.test(flow=flow_path) diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py index 3c92d86d610..99bfe26d8ec 100644 --- a/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_utils.py @@ -158,6 +158,15 @@ def test_generate_connections_dir(self, python_path, env_hash): result = _generate_connections_dir() assert result == expected_result + def test_refresh_connections_dir(self): + from promptflow._core.tools_manager import collect_package_tools_and_connections + + tools, specs, templates = collect_package_tools_and_connections() + + refresh_connections_dir(specs, templates) + conn_dir = _generate_connections_dir() + assert len(os.listdir(conn_dir)) > 0, "No files were generated" + @pytest.mark.parametrize("concurrent_count", [1, 2, 4, 8]) def test_concurrent_execution_of_refresh_connections_dir(self, concurrent_count): threads = [] diff --git a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/data.jsonl b/src/promptflow/tests/test_configs/flows/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl similarity index 100% rename from src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/data.jsonl rename to src/promptflow/tests/test_configs/flows/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl diff --git a/src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/flow_with_package_tool_with_custom_strong_type_connection/flow.dag.yaml similarity index 100% rename from src/promptflow/tests/test_configs/flows/custom_strong_type_connection_basic_flow/flow.dag.yaml rename to src/promptflow/tests/test_configs/flows/flow_with_package_tool_with_custom_strong_type_connection/flow.dag.yaml diff --git a/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/data.jsonl b/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/data.jsonl new file mode 100644 index 00000000000..15e3aa54262 --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/data.jsonl @@ -0,0 +1 @@ +{"text": "Hello World!"} diff --git a/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/flow.dag.yaml b/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/flow.dag.yaml new file mode 100644 index 00000000000..ea623a67a26 --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/flow.dag.yaml @@ -0,0 +1,17 @@ +inputs: + text: + type: string + default: this is an input +outputs: + out: + type: string + reference: ${my_script_tool.output} +nodes: +- name: my_script_tool + type: python + source: + type: code + path: my_script_tool.py + inputs: + connection: custom_connection_2 + input_param: ${inputs.text} diff --git a/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/my_script_tool.py b/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/my_script_tool.py new file mode 100644 index 00000000000..de06c5bffa5 --- /dev/null +++ b/src/promptflow/tests/test_configs/flows/flow_with_script_tool_with_custom_strong_type_connection/my_script_tool.py @@ -0,0 +1,22 @@ +from promptflow import tool +from promptflow.connections import CustomStrongTypeConnection, CustomConnection +from promptflow.contracts.types import Secret + + +class MyCustomConnection(CustomStrongTypeConnection): + """My custom strong type connection. + + :param api_key: The api key. + :type api_key: String + :param api_base: The api base. + :type api_base: String + """ + api_key: Secret + api_url: str = "This is a fake api url." + + +@tool +def my_tool(connection: MyCustomConnection, input_param: str) -> str: + # Replace with your tool code. + # Use custom strong type connection like: connection.api_key, connection.api_url + return f"connection_value is MyCustomConnection: {str(isinstance(connection, MyCustomConnection))}"