Skip to content

Commit

Permalink
Support custom strong type connection in script tool (Simplified) (#733)
Browse files Browse the repository at this point in the history
# Description
doc: [(Simplify) Support custom strong type connection in script
tool.docx](https://microsoftapc-my.sharepoint.com/:w:/g/personal/yalu4_microsoft_com/EYUiMpo5kK5Cj7vJygKEqZQBtl28ZZiWQnljvW9URimFDA?e=kE30rg)

Support custom strong type connection in script tool.
1. Users can define and use their own custom strong type connection in
their script tool:

![image](https://github.com/microsoft/promptflow/assets/46446115/91382a1d-0b6b-460d-876a-e051574e1d8e)
2. The connection type is shown as a custom connection in the flow:

![image](https://github.com/microsoft/promptflow/assets/46446115/a7c5e6ff-e731-4a38-a3fa-54f53f594c6b)
3. The connection value is a normal custom connection.
4. Creating this custom connection is no different from creating a
regular CustomConnection. Users need to fill in the key values of
their custom-defined connection.

![image](https://github.com/microsoft/promptflow/assets/46446115/d8d41822-c8d9-4b9e-85e9-35d16c930f7c)


Local-to-cloud test command:
_Preparation:
install the [promptflow package with the supported
feature](https://msdata.visualstudio.com/Vienna/_build/results?buildId=107506794&view=results)
in the test runtime._
```
pfazure run create --subscription 96aede12-2f73-41cb-b983-6d11a904839b -g promptflow -w chjinche-pf-eus --flow D:\testscripts\test_pf_cmd\tests\new-empty-flow-created-at-2023-10-10 --data D:\testscripts\test_pf_cmd\tests\new-empty-flow-created-at-2023-10-10\data.jsonl --runtime test-compute
```
Link to the successful run:
https://ml.azure.com/prompts/flow/4d49a4bb-8594-4ccd-842e-73a9aeb3fcb1/27ebca2e-8a12-40a9-9f01-bd1ca72815c2/details?wsid=/subscriptions/96aede12-2f73-41cb-b983-6d11a904839b/resourcegroups/promptflow/providers/Microsoft.MachineLearningServices/workspaces/chjinche-pf-eus&tid=72f988bf-86f1-41af-91ab-2d7cd011db47

---------

Co-authored-by: yalu4 <[email protected]>
  • Loading branch information
16oeahr and yalu4 authored Oct 16, 2023
1 parent 696c145 commit 2aa2ef8
Show file tree
Hide file tree
Showing 20 changed files with 349 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def collect_tool_methods_in_module(m):
return tools


def _parse_tool_from_function(f):
def _parse_tool_from_function(f, gen_custom_type_conn=False):
if hasattr(f, "__tool") and isinstance(f.__tool, Tool):
return f.__tool
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(f)
inputs, _, _ = function_to_interface(f, gen_custom_type_conn=gen_custom_type_conn)
except Exception as e:
error_type_and_message = f"({e.__class__.__name__}) {e}"
raise BadFunctionInterface(
Expand Down
7 changes: 4 additions & 3 deletions src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import inspect
import logging
import traceback
import types
from functools import partial
from pathlib import Path
from typing import Callable, List, Mapping, Optional, Tuple, Union
Expand Down Expand Up @@ -347,7 +348,7 @@ def load_tool_for_node(self, node: Node) -> Tool:
if node.source.type == ToolSourceType.Package:
return self.load_tool_for_package_node(node)
elif node.source.type == ToolSourceType.Code:
_, tool = self.load_tool_for_script_node(node)
_, _, tool = self.load_tool_for_script_node(node)
return tool
raise NotImplementedError(f"Tool source type {node.source.type} for python tool is not supported yet.")
elif node.type == ToolType.CUSTOM_LLM:
Expand All @@ -366,15 +367,15 @@ def load_tool_for_package_node(self, node: Node) -> Tool:
target=ErrorTarget.EXECUTOR,
)

def load_tool_for_script_node(self, node: Node) -> Tuple[Callable, Tool]:
def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Callable, Tool]:
if node.source.path is None:
raise UserErrorException(f"Node {node.name} does not have source path defined.")
path = node.source.path
m = load_python_module_from_file(self._working_dir / path)
if m is None:
raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
f = collect_tool_function_in_module(m)
return f, _parse_tool_from_function(f)
return m, f, _parse_tool_from_function(f, gen_custom_type_conn=True)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
Expand Down
29 changes: 22 additions & 7 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def _convert_to_custom(self):
return custom_connection

@classmethod
def _get_custom_keys(cls, data):
def _get_custom_keys(cls, data: Dict):
# The data could be either from yaml or from DB.
# If from yaml, 'custom_type' and 'module' are outside the configs of data.
# If from DB, 'custom_type' and 'module' are within the configs of data.
Expand Down Expand Up @@ -933,13 +933,28 @@ def _is_custom_strong_type(self):
and self.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY]
)

def _convert_to_custom_strong_type(self):
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
import importlib
def _convert_to_custom_strong_type(self, to_class) -> CustomStrongTypeConnection:
# There are two scenarios to convert a custom connection to custom strong type connection:
# 1. The connection is created from a custom strong type connection template file.
# Custom type and module name are present in the configs.
# 2. The connection is created through SDK PFClient or a custom connection template file.
# Custom type and module name are not present in the configs. Module and class must be passed for conversion.
if to_class and not isinstance(to_class, type):
raise TypeError(f"The converted type {to_class} must be a class type.")

module = importlib.import_module(module_name)
custom_defined_connection_class = getattr(module, custom_type_class_name)
if not to_class:
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
import importlib

module = importlib.import_module(module_name)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
custom_defined_connection_class = getattr(module, custom_type_class_name)

if to_class and issubclass(to_class, CustomConnection):
# No need to convert.
return self

custom_defined_connection_class = to_class or custom_defined_connection_class
connection_instance = custom_defined_connection_class(configs=self.configs, secrets=self.secrets)

return connection_instance
Expand Down
50 changes: 44 additions & 6 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def resolve_annotation(anno) -> Union[str, list]:
return args[0] if len(args) == 1 else args


def param_to_definition(param) -> (InputDefinition, bool):
def param_to_definition(param, gen_custom_type_conn=False) -> (InputDefinition, bool):
default_value = param.default
# Get value type and enum from annotation
value_type = resolve_annotation(param.annotation)
enum = None
custom_type_conn = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
Expand All @@ -51,20 +52,57 @@ def param_to_definition(param) -> (InputDefinition, bool):
value_type = str
is_connection = False
if ConnectionType.is_connection_value(value_type):
typ = [value_type.__name__]
if ConnectionType.is_custom_strong_type(value_type):
typ = ["CustomConnection"]
custom_type_conn = [value_type.__name__]
else:
typ = [value_type.__name__]
is_connection = True
elif isinstance(value_type, list):
if not all(ConnectionType.is_connection_value(t) for t in value_type):
typ = [ValueType.OBJECT]
else:
typ = [t.__name__ for t in value_type]
custom_connection_added = False
typ = []
custom_type_conn = []
for t in value_type:
# Add 'CustomConnection' to typ list when custom strong type connection exists. Collect all custom types
if ConnectionType.is_custom_strong_type(t):
if not custom_connection_added:
custom_connection_added = True
typ.append("CustomConnection")
custom_type_conn.append(t.__name__)
else:
if t.__name__ != "CustomConnection":
typ.append(t.__name__)
elif not custom_connection_added:
custom_connection_added = True
typ.append(t.__name__)
is_connection = True
else:
typ = [ValueType.from_type(value_type)]
return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection

# 1. Do not generate custom type when generating flow.tools.json for script tool.
# Extension would show custom type if it exists. While for script tool with custom strong type connection,
# we still want to show 'CustomConnection' type.
# 2. Generate custom connection type when resolving tool in _tool_resolver, since we rely on it to convert the
# custom connection to custom strong type connection.
if not gen_custom_type_conn:
custom_type_conn = None

return (
InputDefinition(
type=typ,
default=value_to_str(default_value),
description=None,
enum=enum,
custom_type=custom_type_conn,
),
is_connection,
)


def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
Expand All @@ -83,7 +121,7 @@ def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
input_def, is_connection = param_to_definition(v)
input_def, is_connection = param_to_definition(v, gen_custom_type_conn=gen_custom_type_conn)
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
Expand Down
4 changes: 4 additions & 0 deletions src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ class InputDefinition:
default: str = None
description: str = None
enum: List[str] = None
# Param 'custom_type' is currently used for inputs of custom strong type connection.
# For a custom strong type connection input, the type should be 'CustomConnection',
# while the custom_type should be the custom strong type connection class name.
custom_type: List[str] = None

def serialize(self) -> dict:
Expand Down Expand Up @@ -285,6 +288,7 @@ def _deserialize_type(v):
data.get("default", ""),
data.get("description", ""),
data.get("enum", []),
data.get("custom_type", []),
)


Expand Down
43 changes: 34 additions & 9 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import copy
import inspect
import types
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional

from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._sdk.entities import CustomConnection
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
Expand Down Expand Up @@ -53,10 +53,6 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c
connection_value = self._connection_manager.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

if isinstance(connection_value, CustomConnection) and connection_value._is_custom_strong_type():
return connection_value._convert_to_custom_strong_type()

# Check if type matched
if not any(type(connection_value).__name__ == typ for typ in conn_types):
msg = (
Expand All @@ -66,7 +62,23 @@ def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, c
raise NodeInputValidationError(message=msg)
return connection_value

def _convert_node_literal_input_types(self, node: Node, tool: Tool):
def _convert_to_custom_strong_type_connection_value(
self, k: str, v: InputAssignment, node: Node, conn_types: List[str], module: types.ModuleType
):
if conn_types is None:
msg = f"Input '{k}' for node '{node.name}' has invalid types: None."
raise NodeInputValidationError(message=msg)
connection_value = self._connection_manager.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

custom_defined_connection_class = None
if node.source.type == ToolSourceType.Code:
custom_type_class_name = conn_types[0]
custom_defined_connection_class = getattr(module, custom_type_class_name)
return connection_value._convert_to_custom_strong_type(to_class=custom_defined_connection_class)

def _convert_node_literal_input_types(self, node: Node, tool: Tool, module: types.ModuleType = None):
updated_inputs = {
k: v
for k, v in node.inputs.items()
Expand All @@ -81,7 +93,12 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool):
value_type = tool_input.type[0]
updated_inputs[k] = InputAssignment(value=v.value, value_type=InputValueType.LITERAL)
if ConnectionType.is_connection_class_name(value_type):
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
if tool_input.custom_type:
updated_inputs[k].value = self._convert_to_custom_strong_type_connection_value(
k, v, node, tool_input.custom_type, module=module
)
else:
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
Expand Down Expand Up @@ -221,9 +238,17 @@ def _resolve_llm_connection_to_inputs(self, node: Node, tool: Tool) -> Node:
)

def _resolve_script_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
f, tool = self._tool_loader.load_tool_for_script_node(node)
m, f, tool = self._tool_loader.load_tool_for_script_node(node)
# We only want to load script tool module once.
# Reloading the same module changes the ID of the class, which can cause issues with isinstance() checks.
# This is important when working with connection class checks. For instance, in user tool script it writes:
# isinstance(conn, MyCustomConnection)
# Custom defined script tool and custom defined strong type connection are in the same module.
# The first time to load the module is in above line when loading a tool.
# We need the module again when converting the custom connection to strong type when converting input types.
# To avoid reloading, pass the loaded module to _convert_node_literal_input_types as an arg.
if convert_input_types:
node = self._convert_node_literal_input_types(node, tool)
node = self._convert_node_literal_input_types(node, tool, m)
return ResolvedTool(node=node, definition=tool, callable=f, init_args={})

def _resolve_package_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,6 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro
gen_tool_by_source("fake_name", tool_source, tool_type, working_dir),
assert str(ex.value) == error_message

def test(self):
tools, specs, templates = collect_package_tools_and_connections()
from promptflow._sdk._utils import refresh_connections_dir

refresh_connections_dir(specs, templates)

def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
# Need to reload pkg_resources to get the latest installed tools
import importlib
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import inspect
from typing import Union

import pytest

from promptflow._utils.tool_utils import function_to_interface
from promptflow._utils.tool_utils import function_to_interface, param_to_definition
from promptflow.connections import AzureOpenAIConnection, CustomConnection
from promptflow.contracts.tool import ValueType

Expand All @@ -24,3 +27,56 @@ def func(input_str: str):
with pytest.raises(Exception) as exec_info:
function_to_interface(func, {"input_str": "test"})
assert "Duplicate inputs found from" in exec_info.value.args[0]

def test_param_to_definition(self):
from promptflow._sdk.entities import CustomStrongTypeConnection
from promptflow.contracts.tool import Secret

class MyFirstConnection(CustomStrongTypeConnection):
api_key: Secret
api_base: str

class MySecondConnection(CustomStrongTypeConnection):
api_key: Secret
api_base: str

def some_func(
conn1: MyFirstConnection,
conn2: Union[CustomConnection, MyFirstConnection],
conn3: Union[MyFirstConnection, CustomConnection],
conn4: Union[MyFirstConnection, MySecondConnection],
conn5: CustomConnection,
conn6: Union[CustomConnection, int],
conn7: Union[MyFirstConnection, int],
):
pass

sig = inspect.signature(some_func)

input_def, _ = param_to_definition(sig.parameters.get("conn1"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn2"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn3"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn4"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection", "MySecondConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn5"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type is None

input_def, _ = param_to_definition(sig.parameters.get("conn6"), gen_custom_type_conn=True)
assert input_def.type == [ValueType.OBJECT]
assert input_def.custom_type is None

input_def, _ = param_to_definition(sig.parameters.get("conn7"), gen_custom_type_conn=True)
assert input_def.type == [ValueType.OBJECT]
assert input_def.custom_type is None
Loading

0 comments on commit 2aa2ef8

Please sign in to comment.