Skip to content

Commit

Permalink
Merge branch 'main' into yigao/move_devkit_azure_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
crazygao committed Apr 11, 2024
2 parents 9340ac3 + a7ba06d commit b43d162
Show file tree
Hide file tree
Showing 29 changed files with 486 additions and 180 deletions.
52 changes: 51 additions & 1 deletion scripts/json_schema/Flow.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
"type": "object",
"additionalProperties": {}
},
"environment_variables": {
"title": "environment_variables",
"type": "object",
"additionalProperties": {}
},
"inputs": {
"title": "inputs",
"type": "object",
Expand All @@ -34,7 +39,13 @@
},
"language": {
"title": "language",
"type": "string"
"type": "string",
"default": "python",
"enum": [
"python",
"csharp"
],
"enumNames": []
},
"node_variants": {
"title": "node_variants",
Expand Down Expand Up @@ -116,8 +127,47 @@
"type": "object",
"additionalProperties": {}
},
"environment_variables": {
"title": "environment_variables",
"type": "object",
"additionalProperties": {}
},
"init": {
"title": "init",
"type": "object",
"additionalProperties": {
"type": "object",
"$ref": "#/definitions/FlexFlowInitSchema"
}
},
"inputs": {
"title": "inputs",
"type": "object",
"additionalProperties": {
"type": "object",
"$ref": "#/definitions/FlexFlowInputSchema"
}
},
"language": {
"title": "language",
"type": "string",
"default": "python",
"enum": [
"python",
"csharp"
],
"enumNames": []
},
"outputs": {
"title": "outputs",
"type": "object",
"additionalProperties": {
"type": "object",
"$ref": "#/definitions/FlexFlowOutputSchema"
}
},
"sample": {
"title": "sample",
"type": "string"
},
"$schema": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ def _get_resource_token(
) -> object:
from promptflow.azure import PFClient

# The default connection_time and read_timeout are both 300s.
# The get token operation should be fast, so we set a short timeout.
pf_client = PFClient(
credential=credential,
subscription_id=subscription_id,
resource_group_name=resource_group_name,
workspace_name=workspace_name,
connection_timeout=15.0,
read_timeout=30.0,
)

token_resp = pf_client._traces._get_cosmos_db_token(container_name=container_name, acquire_write=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from promptflow._constants import CONNECTION_NAME_PROPERTY, CONNECTION_SECRET_KEYS, CustomStrongTypeConnectionConfigs
from promptflow._utils.utils import try_import
from promptflow.contracts.tool import ConnectionType
from promptflow.contracts.types import Secret

from ._connection_provider import ConnectionProvider
Expand Down Expand Up @@ -47,7 +48,7 @@ def _build_connection(connection_dict: dict):
secret_keys = connection_dict.get("secret_keys", [])
secrets = {k: v for k, v in value.items() if k in secret_keys}
configs = {k: v for k, v in value.items() if k not in secrets}
connection_value = connection_class(configs=configs, secrets=secrets, name=key)
connection_value = connection_class(configs=configs, secrets=secrets, name=name)
if CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY in configs:
connection_value.custom_type = configs[CustomStrongTypeConnectionConfigs.PROMPTFLOW_TYPE_KEY]
else:
Expand Down Expand Up @@ -85,7 +86,7 @@ def _build_connections(cls, _dict: Mapping[str, dict]):
def import_requisites(cls, _dict: Mapping[str, dict]):
"""Import connection required modules."""
modules = set()
for key, connection_dict in _dict.items():
for _, connection_dict in _dict.items():
module = connection_dict.get("module")
if module:
modules.add(module)
Expand All @@ -96,5 +97,9 @@ def import_requisites(cls, _dict: Mapping[str, dict]):
def list(self):
return [c for c in self._connections.values()]

def get(self, name: str, **kwargs) -> Any:
return self._connections.get(name)
def get(self, name: str) -> Any:
if isinstance(name, str):
return self._connections.get(name)
elif ConnectionType.is_connection_value(name):
return name
return None
12 changes: 6 additions & 6 deletions src/promptflow-core/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from promptflow._constants import MessageFormatType
from promptflow._core._errors import InvalidSource
from promptflow._core.connection_manager import ConnectionManager
from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._utils.multimedia_utils import MultimediaProcessor
Expand All @@ -23,6 +22,7 @@
get_prompt_param_name_from_func,
)
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import ConnectionProvider
from promptflow.contracts.flow import InputAssignment, InputValueType, Node, ToolSource, ToolSourceType
from promptflow.contracts.tool import ConnectionType, Tool, ToolType, ValueType
from promptflow.contracts.types import AssistantDefinition, PromptTemplate
Expand Down Expand Up @@ -60,7 +60,7 @@ class ToolResolver:
def __init__(
self,
working_dir: Path,
connections: Optional[dict] = None,
connection_provider: Optional[ConnectionProvider] = None,
package_tool_keys: Optional[List[str]] = None,
message_format: str = MessageFormatType.BASIC,
):
Expand All @@ -71,7 +71,7 @@ def __init__(
pass
self._tool_loader = ToolLoader(working_dir, package_tool_keys=package_tool_keys)
self._working_dir = working_dir
self._connection_manager = ConnectionManager(connections)
self._connection_provider = connection_provider
self._multimedia_processor = MultimediaProcessor.create(message_format)

@classmethod
Expand All @@ -83,7 +83,7 @@ def start_resolver(
return resolver

def _convert_to_connection_value(self, k: str, v: InputAssignment, node_name: str, conn_types: List[ValueType]):
connection_value = self._connection_manager.get(v.value)
connection_value = self._connection_provider.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node_name!r} input {k!r}.")
# Check if type matched
Expand All @@ -108,7 +108,7 @@ def _convert_to_custom_strong_type_connection_value(
if not conn_types:
msg = f"Input '{k}' for node '{node_name}' has invalid types: {conn_types}."
raise NodeInputValidationError(message=msg)
connection_value = self._connection_manager.get(v.value)
connection_value = self._connection_provider.get(v.value)
if not connection_value:
raise ConnectionNotFound(f"Connection {v.value} not found for node {node_name!r} input {k!r}.")

Expand Down Expand Up @@ -476,7 +476,7 @@ def _remove_init_args(node_inputs: dict, init_args: dict):
del node_inputs[k]

def _get_llm_node_connection(self, node: Node):
connection = self._connection_manager.get(node.connection)
connection = self._connection_provider.get(node.connection)
if connection is None:
raise ConnectionNotFound(
message_format="Connection '{connection}' of LLM node '{node_name}' is not found.",
Expand Down
25 changes: 18 additions & 7 deletions src/promptflow-core/promptflow/executor/flow_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pathlib import Path
from threading import current_thread
from types import GeneratorType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

import opentelemetry.trace as otel_trace
from opentelemetry.trace.status import StatusCode
Expand All @@ -41,9 +41,11 @@
from promptflow._utils.user_agent_utils import append_promptflow_package_ua
from promptflow._utils.utils import get_int_env_var, transpose
from promptflow._utils.yaml_utils import load_yaml
from promptflow.connections import ConnectionProvider
from promptflow.contracts.flow import Flow, FlowInputDefinition, InputAssignment, InputValueType, Node
from promptflow.contracts.run_info import FlowRunInfo, Status
from promptflow.contracts.run_mode import RunMode
from promptflow.core._connection_provider._dict_connection_provider import DictConnectionProvider
from promptflow.exceptions import PromptflowException
from promptflow.executor import _input_assignment_parser
from promptflow.executor._async_nodes_scheduler import AsyncNodesScheduler
Expand Down Expand Up @@ -98,7 +100,7 @@ class FlowExecutor:
def __init__(
self,
flow: Flow,
connections: dict,
connections: ConnectionProvider,
run_tracker: RunTracker,
cache_manager: AbstractCacheManager,
loaded_tools: Mapping[str, Callable],
Expand All @@ -113,7 +115,7 @@ def __init__(
:param flow: The Flow object to execute.
:type flow: ~promptflow.contracts.flow.Flow
:param connections: The connections between nodes in the Flow.
:type connections: dict
:type connections: Union[dict, ConnectionProvider]
:param run_tracker: The RunTracker object to track the execution of the Flow.
:type run_tracker: ~promptflow._core.run_tracker.RunTracker
:param cache_manager: The AbstractCacheManager object to manage caching of results.
Expand Down Expand Up @@ -170,7 +172,7 @@ def __init__(
def create(
cls,
flow_file: Path,
connections: dict,
connections: Union[dict, ConnectionProvider],
working_dir: Optional[Path] = None,
*,
entry: Optional[str] = None,
Expand All @@ -186,7 +188,7 @@ def create(
:param flow_file: The path to the flow file.
:type flow_file: Path
:param connections: The connections to be used for the flow.
:type connections: dict
:type connections: Union[dict, ConnectionProvider]
:param working_dir: The working directory to be used for the flow. Default is None.
:type working_dir: Optional[str]
:param func: The function to be used for the flow if .py is provided. Default is None.
Expand Down Expand Up @@ -246,7 +248,7 @@ def create(
def _create_from_flow(
cls,
flow: Flow,
connections: dict,
connections: Union[dict, ConnectionProvider],
working_dir: Optional[Path],
*,
flow_file: Optional[Path] = None,
Expand All @@ -262,6 +264,8 @@ def _create_from_flow(
flow = flow._apply_default_node_variants()

package_tool_keys = [node.source.tool for node in flow.nodes if node.source and node.source.tool]
if isinstance(connections, dict):
connections = DictConnectionProvider(connections)
tool_resolver = ToolResolver(working_dir, connections, package_tool_keys, message_format=flow.message_format)

with _change_working_dir(working_dir):
Expand Down Expand Up @@ -393,6 +397,8 @@ def update_operation_context():
inputs = multimedia_processor.load_multimedia_data(node_referenced_flow_inputs, converted_flow_inputs_for_node)
dependency_nodes_outputs = multimedia_processor.load_multimedia_data_recursively(dependency_nodes_outputs)
package_tool_keys = [node.source.tool] if node.source and node.source.tool else []
if isinstance(connections, dict):
connections = DictConnectionProvider(connections)
tool_resolver = ToolResolver(working_dir, connections, package_tool_keys, message_format=flow.message_format)
resolved_node = tool_resolver.resolve_tool_by_node(node)

Expand Down Expand Up @@ -1329,6 +1335,7 @@ def execute_flow(
run_aggregation: bool = True,
enable_stream_output: bool = False,
allow_generator_output: bool = False, # TODO: remove this
init_kwargs: Optional[dict] = None,
**kwargs,
) -> LineResult:
"""Execute the flow, including aggregation nodes.
Expand All @@ -1347,12 +1354,16 @@ def execute_flow(
:type enable_stream_output: Optional[bool]
:param run_id: Run id will be set in operation context and used for session.
:type run_id: Optional[str]
:param init_kwargs: Initialization parameters for flex flow, only supported when flow is callable class.
:type init_kwargs: dict
:param kwargs: Other keyword arguments to create flow executor.
:type kwargs: Any
:return: The line result of executing the flow.
:rtype: ~promptflow.executor._result.LineResult
"""
flow_executor = FlowExecutor.create(flow_file, connections, working_dir, raise_ex=False, **kwargs)
flow_executor = FlowExecutor.create(
flow_file, connections, working_dir, raise_ex=False, init_kwargs=init_kwargs, **kwargs
)
flow_executor.enable_streaming_for_llm_flow(lambda: enable_stream_output)
with _change_working_dir(working_dir), _force_flush_tracer_provider():
# Execute nodes in the flow except the aggregation nodes
Expand Down
4 changes: 4 additions & 0 deletions src/promptflow-devkit/promptflow/_cli/_pf/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def add_parser_test_flow(subparsers):
pf flow test --flow my-awesome-flow --node node_name
# Chat in the flow:
pf flow test --flow my-awesome-flow --node node_name --interactive
# Test a flow with init kwargs:
pf flow test --flow my-awesome-flow --init key1=value1 key2=value2
""" # noqa: E501
add_param_flow = lambda parser: parser.add_argument( # noqa: E731
"--flow", type=str, required=True, help="the flow directory to test."
Expand Down Expand Up @@ -297,6 +299,7 @@ def add_parser_test_flow(subparsers):
add_param_config,
add_param_detail,
add_param_skip_browser,
add_param_init,
] + base_params

if Configuration.get_instance().is_internal_features_enabled():
Expand Down Expand Up @@ -531,6 +534,7 @@ def _test_flow_standard(args, pf_client, inputs, environment_variables):
stream_output=False,
dump_test_result=True,
output_path=args.detail,
init=list_of_dict_to_dict(args.init),
)
# Print flow/node test result
if isinstance(result, dict):
Expand Down
Loading

0 comments on commit b43d162

Please sign in to comment.