Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yalu4 committed Oct 16, 2023
1 parent a3b9f5d commit fadbff3
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_core/tool_meta_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,13 @@ def collect_tool_methods_in_module(m):
return tools


def _parse_tool_from_function(f, should_gen_custom_type=False):
def _parse_tool_from_function(f, gen_custom_type_conn=False):
if hasattr(f, "__tool") and isinstance(f.__tool, Tool):
return f.__tool
if hasattr(f, "__original_function"):
f = f.__original_function
try:
inputs, _, _ = function_to_interface(f, should_gen_custom_type=should_gen_custom_type)
inputs, _, _ = function_to_interface(f, gen_custom_type_conn=gen_custom_type_conn)
except Exception as e:
error_type_and_message = f"({e.__class__.__name__}) {e}"
raise BadFunctionInterface(
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_core/tools_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def load_tool_for_script_node(self, node: Node) -> Tuple[types.ModuleType, Calla
if m is None:
raise CustomToolSourceLoadError(f"Cannot load module from {path}.")
f = collect_tool_function_in_module(m)
return m, f, _parse_tool_from_function(f, should_gen_custom_type=True)
return m, f, _parse_tool_from_function(f, gen_custom_type_conn=True)

def load_tool_for_llm_node(self, node: Node) -> Tool:
api_name = f"{node.provider}.{node.api}"
Expand Down
26 changes: 15 additions & 11 deletions src/promptflow/promptflow/_utils/tool_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def resolve_annotation(anno) -> Union[str, list]:
return args[0] if len(args) == 1 else args


def param_to_definition(param, should_gen_custom_type=False) -> (InputDefinition, bool):
def param_to_definition(param, gen_custom_type_conn=False) -> (InputDefinition, bool):
default_value = param.default
# Get value type and enum from annotation
value_type = resolve_annotation(param.annotation)
enum = None
custom_type = None
custom_type_conn = None
# Get value type and enum from default if no annotation
if default_value is not inspect.Parameter.empty and value_type == inspect.Parameter.empty:
value_type = default_value.__class__ if isinstance(default_value, Enum) else type(default_value)
Expand All @@ -54,7 +54,7 @@ def param_to_definition(param, should_gen_custom_type=False) -> (InputDefinition
if ConnectionType.is_connection_value(value_type):
if ConnectionType.is_custom_strong_type(value_type):
typ = ["CustomConnection"]
custom_type = [value_type.__name__]
custom_type_conn = [value_type.__name__]
else:
typ = [value_type.__name__]
is_connection = True
Expand All @@ -64,14 +64,14 @@ def param_to_definition(param, should_gen_custom_type=False) -> (InputDefinition
else:
custom_connection_added = False
typ = []
custom_type = []
custom_type_conn = []
for t in value_type:
# Add 'CustomConnection' to typ list when custom strong type connection exists. Collect all custom types
if ConnectionType.is_custom_strong_type(t):
if not custom_connection_added:
custom_connection_added = True
typ.append("CustomConnection")
custom_type.append(t.__name__)
custom_type_conn.append(t.__name__)
else:
if t.__name__ != "CustomConnection":
typ.append(t.__name__)
Expand All @@ -85,20 +85,24 @@ def param_to_definition(param, should_gen_custom_type=False) -> (InputDefinition
# 1. Do not generate custom type when generating flow.tools.json for script tool.
# Extension would show custom type if it exists. While for script tool with custom strong type connection,
# we still want to show 'CustomConnection' type.
# 2. Generate custom type when resolving tool in _tool_resolver, since we rely on the custom_type to convert the
# 2. Generate custom connection type when resolving tool in _tool_resolver, since we rely on it to convert the
# custom connection to custom strong type connection.
if not should_gen_custom_type:
custom_type = None
if not gen_custom_type_conn:
custom_type_conn = None

return (
InputDefinition(
type=typ, default=value_to_str(default_value), description=None, enum=enum, custom_type=custom_type
type=typ,
default=value_to_str(default_value),
description=None,
enum=enum,
custom_type=custom_type_conn,
),
is_connection,
)


def function_to_interface(f: Callable, initialize_inputs=None, should_gen_custom_type=False) -> tuple:
def function_to_interface(f: Callable, initialize_inputs=None, gen_custom_type_conn=False) -> tuple:
sign = inspect.signature(f)
all_inputs = {}
input_defs = {}
Expand All @@ -117,7 +121,7 @@ def function_to_interface(f: Callable, initialize_inputs=None, should_gen_custom
)
# Resolve inputs to definitions.
for k, v in all_inputs.items():
input_def, is_connection = param_to_definition(v, should_gen_custom_type=should_gen_custom_type)
input_def, is_connection = param_to_definition(v, gen_custom_type_conn=gen_custom_type_conn)
input_defs[k] = input_def
if is_connection:
connection_types.append(input_def.type)
Expand Down
3 changes: 3 additions & 0 deletions src/promptflow/promptflow/contracts/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ class InputDefinition:
default: str = None
description: str = None
enum: List[str] = None
# Param 'custom_type' is currently used for inputs of custom strong type connection.
# For a custom strong type connection input, the type should be 'CustomConnection',
# while the custom_type should be the custom strong type connection class name.
custom_type: List[str] = None

def serialize(self) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,30 +53,30 @@ def some_func(

sig = inspect.signature(some_func)

input_def, _ = param_to_definition(sig.parameters.get("conn1"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn1"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn2"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn2"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn3"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn3"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn4"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn4"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type == ["MyFirstConnection", "MySecondConnection"]

input_def, _ = param_to_definition(sig.parameters.get("conn5"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn5"), gen_custom_type_conn=True)
assert input_def.type == ["CustomConnection"]
assert input_def.custom_type is None

input_def, _ = param_to_definition(sig.parameters.get("conn6"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn6"), gen_custom_type_conn=True)
assert input_def.type == [ValueType.OBJECT]
assert input_def.custom_type is None

input_def, _ = param_to_definition(sig.parameters.get("conn7"), should_gen_custom_type=True)
input_def, _ = param_to_definition(sig.parameters.get("conn7"), gen_custom_type_conn=True)
assert input_def.type == [ValueType.OBJECT]
assert input_def.custom_type is None
16 changes: 16 additions & 0 deletions src/promptflow/tests/sdk_cli_test/e2etests/test_flow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,22 @@ def test_basic_flow_with_package_tool_with_custom_strong_type_connection(
run = local_client.runs.get(name=result.name)
assert run.status == "Completed"

def test_basic_flow_with_script_tool_with_custom_strong_type_connection(
self, install_custom_tool_pkg, local_client, pf
):
# Prepare custom connection
from promptflow.connections import CustomConnection

conn = CustomConnection(name="custom_connection_2", secrets={"api_key": "test"}, configs={"api_url": "test"})
local_client.connections.create_or_update(conn)

result = pf.run(
flow=f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection",
data=f"{FLOWS_DIR}/flow_with_script_tool_with_custom_strong_type_connection/data.jsonl",
)
run = local_client.runs.get(name=result.name)
assert run.status == "Completed"

def test_run_with_connection_overwrite_non_exist(self, local_client, local_aoai_connection, pf):
# overwrite non_exist connection
with pytest.raises(Exception) as e:
Expand Down
9 changes: 9 additions & 0 deletions src/promptflow/tests/sdk_cli_test/unittests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ def test_generate_connections_dir(self, python_path, env_hash):
result = _generate_connections_dir()
assert result == expected_result

def test_refresh_connections_dir(self):
from promptflow._core.tools_manager import collect_package_tools_and_connections

tools, specs, templates = collect_package_tools_and_connections()

refresh_connections_dir(specs, templates)
conn_dir = _generate_connections_dir()
assert len(os.listdir(conn_dir)) > 0, "No files were generated"

@pytest.mark.parametrize("concurrent_count", [1, 2, 4, 8])
def test_concurrent_execution_of_refresh_connections_dir(self, concurrent_count):
threads = []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"text": "Hello World!"}

0 comments on commit fadbff3

Please sign in to comment.