Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: tool meta generation perf optimization #1250

Merged
merged 4 commits into from
Nov 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 116 additions & 50 deletions src/promptflow/promptflow/_sdk/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import collections
import hashlib
import json
Expand All @@ -18,7 +17,7 @@
from enum import Enum
from os import PathLike
from pathlib import Path
from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union
from typing import IO, Any, AnyStr, Dict, List, Optional, Set, Tuple, Union
from urllib.parse import urlparse

import keyring
Expand Down Expand Up @@ -589,20 +588,46 @@ def _generate_tool_meta(
timeout: int,
*,
include_errors_in_output: bool = False,
load_in_subprocess: bool = True,
elliotzh marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[str, dict]:
"""Generate tool meta from files.

:param flow_directory: flow directory
:param tools: tool list
:param raise_error: whether raise error when generate meta failed
:param timeout: timeout for generate meta
:param include_errors_in_output: whether include errors in output
:param load_in_subprocess: whether load tool meta with subprocess to prevent system path disturb. Default is True.
If set to False, will load tool meta in sync mode and timeout need to be handled outside current process.
:return: tool meta dict
"""
logger = LoggerFactory.get_logger(LOGGER_NAME)
# use multi process generate to avoid system path disturb
manager = multiprocessing.Manager()
tools_dict = manager.dict()
exception_dict = manager.dict()
p = multiprocessing.Process(
target=_generate_meta_from_files, args=(tools, flow_directory, tools_dict, exception_dict)
)
p.start()
p.join(timeout=timeout)
if p.is_alive():
p.terminate()
p.join()
if load_in_subprocess:
# use multiprocess generate to avoid system path disturb
manager = multiprocessing.Manager()
tools_dict = manager.dict()
exception_dict = manager.dict()
p = multiprocessing.Process(
target=_generate_meta_from_files, args=(tools, flow_directory, tools_dict, exception_dict)
)
p.start()
p.join(timeout=timeout)
if p.is_alive():
logger.warning(f"Generate meta timeout after {timeout} seconds, terminate the process.")
p.terminate()
p.join()
else:
tools_dict, exception_dict = {}, {}

# There is no built-in method to forcefully stop a running thread/coroutine in Python
# because abruptly stopping a thread can cause issues like resource leaks,
# deadlocks, or inconsistent states.
# Caller needs to handle the timeout outside current process.
logger.warning(
"Generate meta in current process and timeout won't take effect. "
"Please handle timeout manually outside current process."
)
_generate_meta_from_files(tools, flow_directory, tools_dict, exception_dict)
elliotzh marked this conversation as resolved.
Show resolved Hide resolved
res = {source: tool for source, tool in tools_dict.items()}

for source in res:
Expand Down Expand Up @@ -695,6 +720,68 @@ def _generate_package_tools(keys: Optional[List[str]] = None) -> dict:
return collect_package_tools(keys=keys)


def _update_involved_tools_and_packages(
_node,
_node_path,
*,
tools: List,
used_packages: Set,
source_path_mapping: Dict[str, List[str]],
):
source, tool_type = pydash.get(_node, "source.path", None), _node.get("type", None)

used_packages.add(pydash.get(_node, "source.tool", None))

if source is None or tool_type is None:
return

# for custom LLM tool, its source points to the used prompt template so handle it as prompt tool
if tool_type == ToolType.CUSTOM_LLM:
tool_type = ToolType.PROMPT

if pydash.get(_node, "source.type") not in ["code", "package_with_prompt"]:
return
pair = (source, tool_type.lower())
if pair not in tools:
tools.append(pair)

source_path_mapping[source].append(f"{_node_path}.source.path")


def _get_involved_code_and_package(
data: dict,
) -> Tuple[List[Tuple[str, str]], Set[str], Dict[str, List[str]]]:
tools = [] # List[Tuple[source_file, tool_type]]
used_packages = set()
source_path_mapping = collections.defaultdict(list)

for node_i, node in enumerate(data[NODES]):
_update_involved_tools_and_packages(
node,
f"{NODES}.{node_i}",
tools=tools,
used_packages=used_packages,
source_path_mapping=source_path_mapping,
)

# understand DAG to parse variants
# TODO: should we allow source to appear both in node and node variants?
if node.get(USE_VARIANTS) is True:
node_variants = data[NODE_VARIANTS][node["name"]]
for variant_id in node_variants[VARIANTS]:
node_with_variant = node_variants[VARIANTS][variant_id][NODE]
_update_involved_tools_and_packages(
node_with_variant,
f"{NODE_VARIANTS}.{node['name']}.{VARIANTS}.{variant_id}.{NODE}",
tools=tools,
used_packages=used_packages,
source_path_mapping=source_path_mapping,
)
if None in used_packages:
used_packages.remove(None)
return tools, used_packages, source_path_mapping


def generate_flow_tools_json(
flow_directory: Union[str, Path],
dump: bool = True,
Expand All @@ -713,63 +800,42 @@ def generate_flow_tools_json(
:param raise_error: whether to raise the error, default value is True.
:param timeout: timeout for generation, default value is 60 seconds.
:param include_errors_in_output: whether to include error messages in output, default value is False.
:param target_source: the source name to filter result, default value is None.
:param target_source: the source name to filter result, default value is None. Note that we will update system path
in coroutine if target_source is provided given it's expected to be from a specific cli call.
:param used_packages_only: whether to only include used packages, default value is False.
:param source_path_mapping: if specified, record yaml paths for each source.
"""
flow_directory = Path(flow_directory).resolve()
# parse flow DAG
with open(flow_directory / DAG_FILE_NAME, "r", encoding=DEFAULT_ENCODING) as f:
data = yaml.safe_load(f)
tools = [] # List[Tuple[source_file, tool_type]]
used_packages = set()

def process_node(_node, _node_path):
source, tool_type = pydash.get(_node, "source.path", None), _node.get("type", None)
if target_source and source != target_source:
return
used_packages.add(pydash.get(_node, "source.tool", None))

if source is None or tool_type is None:
return

if tool_type == ToolType.CUSTOM_LLM:
tool_type = ToolType.PROMPT
tools, used_packages, _source_path_mapping = _get_involved_code_and_package(data)

if pydash.get(_node, "source.type") not in ["code", "package_with_prompt"]:
return
tools.append((source, tool_type.lower()))
if source_path_mapping is not None:
if source not in source_path_mapping:
source_path_mapping[source] = []
# update passed in source_path_mapping if specified
if source_path_mapping is not None:
source_path_mapping.update(_source_path_mapping)

source_path_mapping[source].append(f"{_node_path}.source.path")

for node_i, node in enumerate(data[NODES]):
process_node(node, f"{NODES}.{node_i}")

# understand DAG to parse variants
# TODO: should we allow source to appear both in node and node variants?
if node.get(USE_VARIANTS) is True:
node_variants = data[NODE_VARIANTS][node["name"]]
for variant_id in node_variants[VARIANTS]:
current_node = node_variants[VARIANTS][variant_id][NODE]
process_node(current_node, f"{NODE_VARIANTS}.{node['name']}.{VARIANTS}.{variant_id}.{NODE}")

if None in used_packages:
used_packages.remove(None)
# filter tools by target_source if specified
if target_source is not None:
tools = list(filter(lambda x: x[0] == target_source, tools))

# generate content
# TODO: remove type in tools (input) and code (output)
flow_tools = {
"package": _generate_package_tools(keys=list(used_packages) if used_packages_only else None),
"code": _generate_tool_meta(
flow_directory,
tools,
raise_error=raise_error,
timeout=timeout,
include_errors_in_output=include_errors_in_output,
# we don't need to protect system path according to the target usage when target_source is specified
load_in_subprocess=target_source is None,
),
# specified source may only appear in code tools
"package": {}
if target_source is not None
else _generate_package_tools(keys=list(used_packages) if used_packages_only else None),
}

if dump:
Expand Down
24 changes: 15 additions & 9 deletions src/promptflow/promptflow/_sdk/operations/_flow_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import yaml

from promptflow._sdk._constants import CHAT_HISTORY, DEFAULT_ENCODING, LOCAL_MGMT_DB_PATH
from promptflow._sdk._constants import CHAT_HISTORY, DEFAULT_ENCODING, FLOW_TOOLS_JSON_GEN_TIMEOUT, LOCAL_MGMT_DB_PATH
from promptflow._sdk._load_functions import load_flow
from promptflow._sdk._submitter import TestSubmitter
from promptflow._sdk._utils import (
Expand All @@ -27,6 +27,7 @@
generate_random_string,
parse_variant,
)
from promptflow._sdk.entities._flow import ProtectedFlow
from promptflow._sdk.entities._validation import ValidationResult
from promptflow._telemetry.activity import ActivityType, monitor_operation
from promptflow._telemetry.telemetry import TelemetryMixin
Expand Down Expand Up @@ -609,18 +610,18 @@ def validate(self, flow: Union[str, PathLike], *, raise_error: bool = False, **k
:rtype: ValidationResult
"""

flow = load_flow(source=flow)
flow_entity: ProtectedFlow = load_flow(source=flow)

# TODO: put off this if we do path existence check in FlowSchema on fields other than additional_includes
validation_result = flow._validate()
validation_result = flow_entity._validate()

source_path_mapping = {}
flow_tools, tools_errors = self._generate_tools_meta(
flow=flow.flow_dag_path,
flow=flow_entity.flow_dag_path,
source_path_mapping=source_path_mapping,
)

flow.tools_meta_path.write_text(
flow_entity.tools_meta_path.write_text(
data=json.dumps(flow_tools, indent=4),
encoding=DEFAULT_ENCODING,
)
Expand All @@ -634,21 +635,23 @@ def validate(self, flow: Union[str, PathLike], *, raise_error: bool = False, **k
)

# flow in control plane is read-only, so resolve location makes sense even in SDK experience
validation_result.resolve_location_for_diagnostics(flow.flow_dag_path)
validation_result.resolve_location_for_diagnostics(flow_entity.flow_dag_path.as_posix())

flow._try_raise(
flow_entity._try_raise(
validation_result,
raise_error=raise_error,
)

return validation_result

@monitor_operation(activity_name="pf.flows._generate_tools_meta", activity_type=ActivityType.INTERNALCALL)
def _generate_tools_meta(
self,
flow: Union[str, PathLike],
*,
source_name: str = None,
source_path_mapping: Dict[str, List[str]] = None,
timeout: int = FLOW_TOOLS_JSON_GEN_TIMEOUT,
) -> Tuple[dict, dict]:
"""Generate flow tools meta for a specific flow or a specific node in the flow.

Expand All @@ -663,12 +666,14 @@ def _generate_tools_meta(
:param source_name: source name to generate tools meta. If not specified, generate tools meta for all sources.
:type source_name: str
:param source_path_mapping: If passed in None, do nothing; if passed in a dict, will record all reference yaml
paths for each source.
paths for each source in the dict passed in.
:type source_path_mapping: Dict[str, List[str]]
:param timeout: timeout for generating tools meta
:type timeout: int
:return: dict of tools meta and dict of tools errors
:rtype: Tuple[dict, dict]
"""
flow = load_flow(source=flow)
flow: ProtectedFlow = load_flow(source=flow)

with self._resolve_additional_includes(flow.flow_dag_path) as new_flow_dag_path:
flow_tools = generate_flow_tools_json(
Expand All @@ -679,6 +684,7 @@ def _generate_tools_meta(
target_source=source_name,
used_packages_only=True,
source_path_mapping=source_path_mapping,
timeout=timeout,
)

flow_tools_meta = flow_tools.pop("code", {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def test_flow_validation_failed(self, pf) -> None:
"convert_to_dict.py": {
"function": "convert_to_dict",
"inputs": {"input_str": {"type": ["string"]}},
"source": os.path.join("..", "external_files", "convert_to_dict.py"),
"source": "convert_to_dict.py",
"type": "python",
},
"fetch_text_content_from_url.py": {
Expand Down Expand Up @@ -365,7 +365,7 @@ def test_flow_generate_tools_meta(self, pf) -> None:
"convert_to_dict.py": {
"function": "convert_to_dict",
"inputs": {"input_str": {"type": ["string"]}},
"source": os.path.join("..", "external_files", "convert_to_dict.py"),
"source": "convert_to_dict.py",
"type": "python",
},
"fetch_text_content_from_url.py": {
Expand Down Expand Up @@ -406,6 +406,23 @@ def test_flow_generate_tools_meta(self, pf) -> None:
}
assert tools_error == {}

@pytest.mark.skip(reason="It will fail in CI for some reasons. Still need to investigate.")
def test_flow_generate_tools_meta_timeout(self, pf) -> None:
source = f"{FLOWS_DIR}/web_classification_invalid"

for tools_meta, tools_error in [
pf.flows._generate_tools_meta(source, timeout=1),
# There is no built-in method to forcefully stop a running thread in Python
# because abruptly stopping a thread can cause issues like resource leaks,
# deadlocks, or inconsistent states.
# Caller (VSCode extension) will handle the timeout error.
# pf.flows._generate_tools_meta(source, source_name="convert_to_dict.py", timeout=1),
]:
assert tools_meta == {"code": {}, "package": {}}
assert tools_error
for error in tools_error.values():
assert "timeout" in error

def test_flow_generate_tools_meta_with_pkg_tool_with_custom_strong_type_connection(self, pf) -> None:
source = f"{FLOWS_DIR}/flow_with_package_tool_with_custom_strong_type_connection"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import json
import time

from promptflow import tool


# use this to test the timeout
time.sleep(2)


@tool
def convert_to_dict(input_str: str):
try:
return json.loads(input_str)
except Exception as e:
print("input is not valid, error: {}".format(e))
return {"category": "None", "evidence": "None"}
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,4 @@ node_variants:
api: completion
module: promptflow.tools.aoai
additional_includes:
- ../external_files/convert_to_dict.py
- ../external_files/fetch_text_content_from_url.py
Loading