Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
yalu4 committed Oct 12, 2023
1 parent cb7788e commit 29aa602
Show file tree
Hide file tree
Showing 15 changed files with 203 additions and 35 deletions.
15 changes: 12 additions & 3 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import inspect
import json
import re
import sys
import types
from dataclasses import asdict
from pathlib import Path
Expand Down Expand Up @@ -111,13 +112,13 @@ def collect_tool_methods_in_module(m):
return tools


def _parse_tool_from_function(f):
def _parse_tool_from_function(f, should_gen_custom_type=False):
if hasattr(f, "__tool") and isinstance(f.__tool, Tool):
return f.__tool
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(f)
inputs, _, _ = function_to_interface(f, should_gen_custom_type=should_gen_custom_type)
except Exception as e:
raise BadFunctionInterface(f"Failed to parse interface for tool {f.__name__}, reason: {e}") from e
class_name = None
Expand Down Expand Up @@ -149,7 +150,14 @@ def generate_python_tools_in_module_as_dict(module):
def load_python_module_from_file(src_file: Path):
# Here we hard code the module name as __pf_main__ since it is invoked as a main script in pf.
src_file = Path(src_file).resolve() # Make sure the path is absolute to align with python import behavior.
spec = importlib.util.spec_from_file_location("__pf_main__", location=src_file)
module_name = "__pf_main__"

# If the same module gets reloaded, the id of the class changes, leading to issues with isinstance() checks.
# To avoid this, return the already loaded module if it exists in sys.modules instead of reloading it.
if module_name in sys.modules:
return sys.modules[module_name]

spec = importlib.util.spec_from_file_location(module_name, location=src_file)
if spec is None or spec.loader is None:
raise PythonLoaderNotFound(f"Failed to load python file '{src_file}', please make sure it is a valid .py file.")
m = importlib.util.module_from_spec(spec)
Expand All @@ -158,6 +166,7 @@ def load_python_module_from_file(src_file: Path):
except Exception as e:
# TODO: add stacktrace to additional info
raise PythonLoadError(f"Failed to load python module from file '{src_file}', reason: {e}.") from e
sys.modules[module_name] = m
return m


Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def load_tool_for_script_node(self, node: Node) -> Tuple[Callable, Tool]:
if m is None:
raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
f = collect_tool_function_in_module(m)
return f, _parse_tool_from_function(f)
return f, _parse_tool_from_function(f, should_gen_custom_type=True)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
Expand Down
12 changes: 7 additions & 5 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,12 +933,14 @@ def _is_custom_strong_type(self):
and self.configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY]
)

def _convert_to_custom_strong_type(self):
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
custom_type_class_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
import importlib
def _convert_to_custom_strong_type(self, module, class_name):
if not module:
module_name = self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_MODULE_KEY)
import importlib

module = importlib.import_module(module_name)
module = importlib.import_module(module_name)

custom_type_class_name = class_name or self.configs.get(CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY)
custom_defined_connection_class = getattr(module, custom_type_class_name)
connection_instance = custom_defined_connection_class(configs=self.configs, secrets=self.secrets)

Expand Down
34 changes: 28 additions & 6 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ def resolve_annotation(anno) -> Union[str, list]:
return args[0] if len(args) == 1 else args


def param_to_definition(param) -> (InputDefinition, bool):
def param_to_definition(param, should_gen_custom_type=False) -> (InputDefinition, bool):
default_value = param.default
# Get value type and enum from annotation
value_type = resolve_annotation(param.annotation)
enum = None
custom_type = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
Expand All @@ -51,20 +52,41 @@ def param_to_definition(param) -> (InputDefinition, bool):
value_type = str
is_connection = False
if ConnectionType.is_connection_value(value_type):
typ = [value_type.__name__]
if ConnectionType.is_custom_strong_type(value_type):
typ = ["CustomConnection"]
custom_type = [value_type.__name__]
else:
typ = [value_type.__name__]
is_connection = True
elif isinstance(value_type, list):
if not all(ConnectionType.is_connection_value(t) for t in value_type):
typ = [ValueType.OBJECT]
else:
typ = [t.__name__ for t in value_type]
custom_connection_added = False
typ = []
custom_type = []
for t in value_type:
if ConnectionType.is_custom_strong_type(t):
if not custom_connection_added:
custom_connection_added = True
typ.append("CustomConnection")
custom_type.append(t.__name__)
else:
typ.append(t.__name__)
is_connection = True
else:
typ = [ValueType.from_type(value_type)]
return InputDefinition(type=typ, default=value_to_str(default_value), description=None, enum=enum), is_connection
if not should_gen_custom_type:
custom_type = None
return (
InputDefinition(
type=typ, default=value_to_str(default_value), description=None, enum=enum, custom_type=custom_type
),
is_connection,
)


def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
def function_to_interface(f: Callable, initialize_inputs=None, should_gen_custom_type=False) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
Expand All @@ -83,7 +105,7 @@ def function_to_interface(f: Callable, initialize_inputs=None) -> tuple:
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
input_def, is_connection = param_to_definition(v)
input_def, is_connection = param_to_definition(v, should_gen_custom_type=should_gen_custom_type)
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
Expand Down
20 changes: 15 additions & 5 deletions src/promptflow/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from typing import Callable, List, Optional

from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tool_meta_generator import load_python_module_from_file
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._sdk.entities import CustomConnection
from promptflow._utils.tool_utils import get_inputs_for_prompt_template, get_prompt_param_name_from_func
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
Expand Down Expand Up @@ -49,13 +49,15 @@ def __init__(
self._working_dir = working_dir
self._connection_manager = ConnectionManager(connections)

def _convert_to_connection_value(self, k: str, v: InputAssignment, node: Node, conn_types: List[ValueType]):
def _convert_to_connection_value(
self, k: str, v: InputAssignment, node: Node, conn_types: List[ValueType], should_convert=False, module=None
):
connection_value = self._connection_manager.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node.name!r} input {k!r}.")

if isinstance(connection_value, CustomConnection) and connection_value._is_custom_strong_type():
return connection_value._convert_to_custom_strong_type()
if should_convert:
return connection_value._convert_to_custom_strong_type(module, conn_types[0])

# Check if type matched
if not any(type(connection_value).__name__ == typ for typ in conn_types):
Expand All @@ -81,7 +83,15 @@ def _convert_node_literal_input_types(self, node: Node, tool: Tool):
value_type = tool_input.type[0]
updated_inputs[k] = InputAssignment(value=v.value, value_type=InputValueType.LITERAL)
if ConnectionType.is_connection_class_name(value_type):
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
if tool_input.custom_type:
custom_m = None
if node.type == ToolType.PYTHON:
custom_m = load_python_module_from_file(node.source.path)
updated_inputs[k].value = self._convert_to_connection_value(
k, v, node, tool_input.custom_type, should_convert=True, module=custom_m
)
else:
updated_inputs[k].value = self._convert_to_connection_value(k, v, node, tool_input.type)
elif isinstance(value_type, ValueType):
try:
updated_inputs[k].value = value_type.parse(v.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,6 @@ def test_gen_tool_by_source_error(self, tool_source, tool_type, error_code, erro
gen_tool_by_source("fake_name", tool_source, tool_type, working_dir),
assert str(ex.value) == error_message

def test(self):
tools, specs, templates = collect_package_tools_and_connections()
from promptflow._sdk._utils import refresh_connections_dir

refresh_connections_dir(specs, templates)

def test_collect_package_tools_and_connections(self, install_custom_tool_pkg):
# Need to reload pkg_resources to get the latest installed tools
import importlib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,26 @@ def mock_package_func(prompt: PromptTemplate, **kwargs):
resolved_tool = tool_resolver._integrate_prompt_in_package_node(node, resolved_tool)
kwargs = {k: v.value for k, v in resolved_tool.node.inputs.items()}
assert resolved_tool.callable(**kwargs) == "Hello World!"

@pytest.mark.parametrize(
"conn_types",
"expected_res",
[
(None, None)
# ([CustomConnection, MyFirstCSTConnection], None),
# ([CustomConnection, MyFirstCSTConnection, MySecondCSTConnection], None),
],
)
def test_convert_to_connection_value(self, mocker, conn_types):
connections = None
tool_resolver = ToolResolver(working_dir=None, connections=connections)
# For custom strong type, need to consider the conn_types as:
# 1. conn_types is None
# 2. conn_types is a list of custom strong type
# a. [CustomConnection, MyFirstCSTConnection]
# b. [CustomConnection, MyFirstCSTConnection, MySecondCSTConnection]
# c. [MyFirstCSTConnection, MySecondCSTConnection]
# d. [MyFirstCSTConnection]
conn_types = None
tool_resolver._convert_to_connection_value("conn_name", None, None, conn_types)
raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def test_basic_flow_run_bulk_without_env(self, pf, runtime) -> None:
assert isinstance(run, Run)

@pytest.mark.skip("Custom tool pkg and promptprompt pkg with CustomStrongTypeConnection not installed on runtime.")
def test_basic_flow_run_with_custom_strong_type_connection(self, pf, runtime) -> None:
def test_basic_flow_with_package_tool_with_custom_strong_type_connection(self, pf, runtime) -> None:
name = str(uuid.uuid4())
run_pf_command(
"run",
"create",
"--flow",
f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection",
"--data",
f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl",
"--name",
name,
pf=pf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,58 @@ def test_flow_generate_tools_meta(self, pf) -> None:
"package": {},
}
assert tools_error == {}

def test_flow_generate_tools_meta_with_pkg_tool_with_custom_strong_type_connection(self, pf) -> None:
source = f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection"

tools_meta, tools_error = pf.flows._generate_tools_meta(source)

assert tools_error == {}
assert tools_meta["code"] == {}
assert tools_meta["package"] == {
"my_tool_package.tools.my_tool_1.my_tool": {
"function": "my_tool",
"inputs": {
"connection": {
"type": ["CustomConnection"],
"custom_type": ["MyFirstConnection", "MySecondConnection"],
},
"input_text": {"type": ["string"]},
},
"module": "my_tool_package.tools.my_tool_1",
"name": "My First Tool",
"description": "This is my first tool",
"type": "python",
"package": "test-custom-tools",
"package_version": "0.0.2",
},
"my_tool_package.tools.my_tool_2.MyTool.my_tool": {
"class_name": "MyTool",
"function": "my_tool",
"inputs": {
"connection": {"type": ["CustomConnection"], "custom_type": ["MySecondConnection"]},
"input_text": {"type": ["string"]},
},
"module": "my_tool_package.tools.my_tool_2",
"name": "My Second Tool",
"description": "This is my second tool",
"type": "python",
"package": "test-custom-tools",
"package_version": "0.0.2",
},
}

def test_flow_generate_tools_meta_with_script_tool_with_custom_strong_type_connection(self, pf) -> None:
source = f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection"

tools_meta, tools_error = pf.flows._generate_tools_meta(source)
assert tools_error == {}
assert tools_meta["package"] == {}
assert tools_meta["code"] == {
"my_script_tool.py": {
"function": "my_tool",
"inputs": {"connection": {"type": ["CustomConnection"]}, "input_param": {"type": ["string"]}},
"source": "my_script_tool.py",
"type": "python",
}
}
8 changes: 5 additions & 3 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def test_custom_connection_overwrite(self, local_client, local_custom_connection
)
assert "Connection with name new_connection not found" in str(e.value)

def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg, local_client, pf):
def test_basic_flow_with_package_tool_with_custom_strong_type_connection(
self, install_custom_tool_pkg, local_client, pf
):
# Need to reload pkg_resources to get the latest installed tools
import importlib

Expand All @@ -256,8 +258,8 @@ def test_custom_strong_type_connection_basic_flow(self, install_custom_tool_pkg,
importlib.reload(pkg_resources)

result = pf.run(
flow=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow",
data=f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow/data.jsonl",
flow=f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection",
data=f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection/data.jsonl",
connections={"My_First_Tool_00f8": {"connection": "custom_strong_type_connection"}},
)
run = local_client.runs.get(name=result.name)
Expand Down
18 changes: 15 additions & 3 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_pf_test_flow(self):
result = _client.test(flow=f"{FLOWS_DIR}/web_classification")
assert all([key in FLOW_RESULT_KEYS for key in result])

def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_tool_pkg):
def test_pf_test_flow_with_package_tool_with_custom_strong_type_connection(self, install_custom_tool_pkg):
# Need to reload pkg_resources to get the latest installed tools
import importlib

Expand All @@ -42,16 +42,28 @@ def test_pf_test_flow_with_custom_strong_type_connection(self, install_custom_to
importlib.reload(pkg_resources)

inputs = {"text": "Hello World!"}
flow_path = Path(f"{FLOWS_DIR}/custom_strong_type_connection_basic_flow").absolute()
flow_path = Path(f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection").absolute()

# Test that connection would be custom strong type in flow
result = _client.test(flow=flow_path, inputs=inputs)
assert result == {"out": "connection_value is MyFirstConnection: True"}

# Test that connection
# Test node run
result = _client.test(flow=flow_path, inputs={"input_text": "Hello World!"}, node="My_Second_Tool_usi3")
assert result == "Hello World!This is my first custom connection."

def test_pf_test_flow_with_script_tool_with_custom_strong_type_connection(self):
inputs = {"text": "Hello World!"}
flow_path = Path(f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection").absolute()

# Test that connection would be custom strong type in flow
result = _client.test(flow=flow_path, inputs=inputs)
assert result == {"out": "connection_value is MyCustomConnection: True"}

# Test node run
result = _client.test(flow=flow_path, inputs={"input_param": "Hello World!"}, node="my_script_tool")
assert result == "connection_value is MyCustomConnection: True"

def test_pf_test_with_streaming_output(self):
flow_path = Path(f"{FLOWS_DIR}/chat_flow_with_stream_output")
result = _client.test(flow=flow_path)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
inputs:
text:
type: string
default: this is an input
outputs:
out:
type: string
reference: ${my_script_tool.output}
nodes:
- name: my_script_tool
type: python
source:
type: code
path: my_script_tool.py
inputs:
connection: custom_connection_2
input_param: ${inputs.text}
Loading

0 comments on commit 29aa602

Please sign in to comment.