Merge branch 'main' into brynn/separate-core-connection
Signed-off-by: Brynn Yin <[email protected]>
brynn-code committed Mar 12, 2024
2 parents cc066d9 + 39039cd commit 4469654
Showing 29 changed files with 315 additions and 180 deletions.
@@ -125,14 +125,21 @@ pip install my-tools-package>=0.0.8
### I'm a tool author and want to dynamically list Azure resources in my tool input. What should I pay attention to?
1. Declare the Azure workspace triple "subscription_id", "resource_group_name", and "workspace_name" in the list function signature. The system appends the workspace triple to the function's input parameters when they appear in the signature. See [list_endpoint_names](https://github.com/microsoft/promptflow/blob/main/examples/tools/tool-package-quickstart/my_tool_package/tools/tool_with_dynamic_list_input.py) as an example.
```python
-def list_endpoint_names(subscription_id, resource_group_name, workspace_name, prefix: str = "") -> List[Dict[str, str]]:
+def list_endpoint_names(subscription_id: str = None,
+                        resource_group_name: str = None,
+                        workspace_name: str = None,
+                        prefix: str = "") -> List[Dict[str, str]]:
    """This is an example to show how to get Azure ML resource in tool input list function.
    :param subscription_id: Azure subscription id.
    :param resource_group_name: Azure resource group name.
    :param workspace_name: Azure ML workspace name.
    :param prefix: prefix to add to each item.
    """
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []
    from azure.ai.ml import MLClient
    from azure.identity import DefaultAzureCredential
```
@@ -185,4 +192,13 @@ If you are unable to see any options in a dynamic list tool input, you may see a
If this occurs, follow these troubleshooting steps:

- Note the exact error message shown. This provides details on why the dynamic list failed to populate.
+- Check the tool documentation for any prerequisites or special instructions. For example, if the dynamic list function requires Azure credentials, ensure you have installed the Azure dependencies, logged in, and set the default workspace:
+```sh
+pip install azure-ai-ml
+```
+```sh
+az login
+az account set --subscription <subscription_id>
+az configure --defaults group=<resource_group_name> workspace=<workspace_name>
+```
- Contact the tool author/support team and report the issue. Provide the error message so they can investigate the root cause.
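For context, the guard added in this commit means the list function degrades gracefully instead of raising when the workspace triad has not been resolved yet. A minimal sketch of exercising it, assuming `my-tools-package` is installed and importable as `my_tool_package`:

```python
# A sketch, not part of the commit: exercising the guard added above.
# Assumes my-tools-package is installed; the placeholders are illustrative.
from my_tool_package.tools.tool_with_dynamic_list_input import list_endpoint_names

# With no workspace triad, the function now returns an empty list instead of failing.
assert list_endpoint_names() == []

# With a full triad it queries Azure ML, so credentials (e.g. `az login`) are required:
# endpoints = list_endpoint_names(
#     subscription_id="<subscription_id>",
#     resource_group_name="<resource_group_name>",
#     workspace_name="<workspace_name>",
# )
```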
@@ -31,14 +31,21 @@ def my_list_func(prefix: str = "", size: int = 10, **kwargs) -> List[Dict[str, U
    return result


-def list_endpoint_names(subscription_id, resource_group_name, workspace_name, prefix: str = "") -> List[Dict[str, str]]:
+def list_endpoint_names(subscription_id: str = None,
+                        resource_group_name: str = None,
+                        workspace_name: str = None,
+                        prefix: str = "") -> List[Dict[str, str]]:
    """This is an example to show how to get Azure ML resource in tool input list function.
    :param subscription_id: Azure subscription id.
    :param resource_group_name: Azure resource group name.
    :param workspace_name: Azure ML workspace name.
    :param prefix: prefix to add to each item.
    """
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

20 changes: 14 additions & 6 deletions src/promptflow-tools/promptflow/tools/open_model_llm.py
@@ -521,11 +521,15 @@ def parse_endpoint_connection_type(endpoint_connection_name: str) -> Tuple[str,
    return (endpoint_connection_details[0].lower(), endpoint_connection_details[1])


-def list_endpoint_names(subscription_id: str,
-                        resource_group_name: str,
-                        workspace_name: str,
+def list_endpoint_names(subscription_id: str = None,
+                        resource_group_name: str = None,
+                        workspace_name: str = None,
                         return_endpoint_url: bool = False,
                         force_refresh: bool = False) -> List[Dict[str, Union[str, int, float, list, Dict]]]:
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []

    cache_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
@@ -598,10 +602,14 @@ def list_endpoint_names(subscription_id: str,
    return list_of_endpoints


-def list_deployment_names(subscription_id: str,
-                          resource_group_name: str,
-                          workspace_name: str,
+def list_deployment_names(subscription_id: str = None,
+                          resource_group_name: str = None,
+                          workspace_name: str = None,
                           endpoint: str = None) -> List[Dict[str, Union[str, int, float, list, Dict]]]:
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []

    deployment_default_list = [{
        "value": DEPLOYMENT_DEFAULT,
        "display_value": DEPLOYMENT_DEFAULT,
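The same three-line guard is now repeated in every dynamic list function in this file. A hypothetical refactor, not part of this commit, could centralize it in a decorator; a sketch:

```python
from functools import wraps
from typing import Callable, Dict, List, Union


def requires_workspace_triad(list_func: Callable) -> Callable:
    """Hypothetical helper (not in this commit): short-circuit to [] when the
    workspace triad is missing, mirroring the guard added in each function above."""
    @wraps(list_func)
    def wrapper(subscription_id: str = None,
                resource_group_name: str = None,
                workspace_name: str = None,
                **kwargs) -> List[Dict[str, Union[str, int, float, list, Dict]]]:
        if not subscription_id or not resource_group_name or not workspace_name:
            return []
        return list_func(subscription_id, resource_group_name, workspace_name, **kwargs)
    return wrapper
```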
16 changes: 7 additions & 9 deletions src/promptflow/promptflow/_cli/_pf/_experiment.py
@@ -131,13 +131,13 @@ def add_experiment_start(subparsers):
# Start a named experiment:
pf experiment start -n my_experiment --inputs data1=data1_val data2=data2_val
# Run an experiment by yaml file:
-pf experiment start --file path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
+pf experiment start --template path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
"""
    activate_action(
        name="start",
        description="Start an experiment.",
        epilog=epilog,
-        add_params=[add_param_name, add_param_file, add_param_input, add_param_stream] + base_params,
+        add_params=[add_param_name, add_param_template, add_param_input, add_param_stream] + base_params,
        subparsers=subparsers,
        help_message="Start an experiment.",
        action_param_name="sub_action",
@@ -235,20 +235,18 @@ def start_experiment(args: argparse.Namespace):
    if args.name:
        logger.debug(f"Starting a named experiment {args.name}.")
        inputs = list_of_dict_to_dict(args.inputs)
-        if inputs:
-            logger.warning("The inputs of named experiment cannot be modified.")
        client = _get_pf_client()
        experiment = client._experiments.get(args.name)
-        result = client._experiments.start(experiment=experiment, stream=args.stream)
-    elif args.file:
+        result = client._experiments.start(experiment=experiment, inputs=inputs, stream=args.stream)
+    elif args.template:
        from promptflow._sdk._load_functions import _load_experiment

-        logger.debug(f"Starting an anonymous experiment {args.file}.")
-        experiment = _load_experiment(source=args.file)
+        logger.debug(f"Starting an anonymous experiment {args.template}.")
+        experiment = _load_experiment(source=args.template)
        inputs = list_of_dict_to_dict(args.inputs)
        result = _get_pf_client()._experiments.start(experiment=experiment, inputs=inputs, stream=args.stream)
    else:
-        raise UserErrorException("To start an experiment, one of [name, file] must be specified.")
+        raise UserErrorException("To start an experiment, one of [name, template] must be specified.")
    print(json.dumps(result._to_dict(), indent=4))


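With this rename, the CLI accepts `--template` where it previously accepted `--file`, and named experiments now forward `--inputs` instead of warning that they are ignored. A usage sketch mirroring the epilog above (paths and input names are illustrative):

```sh
# Start a named experiment; inputs are now passed through to the run.
pf experiment start -n my_experiment --inputs data1=data1_val data2=data2_val
# Start an anonymous experiment from a template YAML (formerly --file).
pf experiment start --template path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
```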
9 changes: 9 additions & 0 deletions src/promptflow/promptflow/_constants.py
@@ -10,6 +10,7 @@
PROMPTFLOW_CONNECTIONS = "PROMPTFLOW_CONNECTIONS"
PROMPTFLOW_SECRETS_FILE = "PROMPTFLOW_SECRETS_FILE"
PF_NO_INTERACTIVE_LOGIN = "PF_NO_INTERACTIVE_LOGIN"
+PF_RUN_AS_BUILT_BINARY = "PF_RUN_AS_BUILT_BINARY"
PF_LOGGING_LEVEL = "PF_LOGGING_LEVEL"
OPENAI_API_KEY = "openai-api-key"
BING_API_KEY = "bing-api-key"
@@ -19,6 +20,7 @@
ERROR_RESPONSE_COMPONENT_NAME = "promptflow"
EXTENSION_UA = "prompt-flow-extension"
LANGUAGE_KEY = "language"
+USER_AGENT_OVERRIDE_KEY = "user_agent_override"

# Tool meta info
ICON_DARK = "icon_dark"
@@ -166,6 +168,13 @@ class MessageFormatType:


DEFAULT_OUTPUT_NAME = "output"
+OUTPUT_FILE_NAME = "output.jsonl"


+class OutputsFolderName:
+    FLOW_OUTPUTS = "flow_outputs"
+    FLOW_ARTIFACTS = "flow_artifacts"
+    NODE_ARTIFACTS = "node_artifacts"


class ConnectionType(str, Enum):
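The new constants name the on-disk layout of a run's outputs. A hypothetical sketch of how they might be combined, assuming the folders live under a run directory (the helper below is illustrative, not promptflow API):

```python
from pathlib import Path

# Values mirror the constants added above.
OUTPUT_FILE_NAME = "output.jsonl"


class OutputsFolderName:
    FLOW_OUTPUTS = "flow_outputs"
    FLOW_ARTIFACTS = "flow_artifacts"
    NODE_ARTIFACTS = "node_artifacts"


def scaffold_run_outputs(run_dir: Path) -> Path:
    """Create the three output folders and return the aggregated outputs file path."""
    for folder in (OutputsFolderName.FLOW_OUTPUTS,
                   OutputsFolderName.FLOW_ARTIFACTS,
                   OutputsFolderName.NODE_ARTIFACTS):
        (run_dir / folder).mkdir(parents=True, exist_ok=True)
    return run_dir / OutputsFolderName.FLOW_OUTPUTS / OUTPUT_FILE_NAME
```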
3 changes: 3 additions & 0 deletions src/promptflow/promptflow/_core/operation_context.py
@@ -153,6 +153,9 @@ def append_user_agent(self, user_agent: str):
            user_agent (str): The user agent information to append.
        """
        if OperationContext.USER_AGENT_KEY in self:
+            # TODO: this judgement can be wrong when a user agent is a substring of another,
+            # e.g. "Mozilla/5.0" and "Mozilla/5.0 (Windows NT 10.0; Win64; x64)";
+            # however, changing this code may impact existing logic, so it is left as-is for now.
            if user_agent not in self.user_agent:
                self.user_agent = f"{self.user_agent.strip()} {user_agent.strip()}"
        else:
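The TODO above flags a substring false positive: with the current check, appending "foo/1" is skipped whenever the existing string already contains "foo/1.2". A token-based comparison would avoid that; a sketch of the alternative (not what the code does today; the version strings are illustrative):

```python
def append_user_agent_exact(context: dict, user_agent: str) -> None:
    """Sketch of a token-exact alternative to the substring check above.
    "foo/1" is appended even when "foo/1.2" is already present, because the
    comparison happens on whole whitespace-delimited tokens."""
    tokens = context.get("user_agent", "").split()
    for part in user_agent.split():
        if part not in tokens:
            tokens.append(part)
    context["user_agent"] = " ".join(tokens)


ctx = {"user_agent": "promptflow-sdk/1.7 foo/1.2"}
append_user_agent_exact(ctx, "foo/1")
assert ctx["user_agent"] == "promptflow-sdk/1.7 foo/1.2 foo/1"
```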
7 changes: 1 addition & 6 deletions src/promptflow/promptflow/_sdk/_load_functions.py
@@ -1,7 +1,6 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
-import hashlib
from os import PathLike
from pathlib import Path
from typing import IO, AnyStr, Optional, Union
@@ -11,7 +10,6 @@
from .._utils.logger_utils import get_cli_sdk_logger
from .._utils.yaml_utils import load_yaml
from ._errors import MultipleExperimentTemplateError, NoExperimentTemplateError
-from ._utils import _sanitize_python_variable_name
from .entities import Run
from .entities._connection import CustomConnection, _Connection
from .entities._experiment import Experiment, ExperimentTemplate
@@ -205,8 +203,5 @@ def _load_experiment(
    absolute_path = source.resolve().absolute().as_posix()
    if not source.exists():
        raise NoExperimentTemplateError(f"Experiment file {absolute_path} not found.")
-    anonymous_exp_name = _sanitize_python_variable_name(
-        f"{source.stem}_{hashlib.sha1(absolute_path.encode('utf-8')).hexdigest()}"
-    )
-    experiment = load_common(Experiment, source, params_override=[{"name": anonymous_exp_name}], **kwargs)
+    experiment = load_common(Experiment, source, **kwargs)
    return experiment
51 changes: 45 additions & 6 deletions src/promptflow/promptflow/_sdk/_pf_client.py
@@ -6,12 +6,14 @@
from pathlib import Path
from typing import Any, Dict, List, Union

+from .._constants import USER_AGENT_OVERRIDE_KEY
from .._utils.logger_utils import get_cli_sdk_logger
from ..exceptions import ErrorTarget, UserErrorException
from ._configuration import Configuration
-from ._constants import MAX_SHOW_DETAILS_RESULTS
+from ._constants import MAX_SHOW_DETAILS_RESULTS, ConnectionProvider
from ._load_functions import load_flow
from ._user_agent import USER_AGENT
-from ._utils import ClientUserAgentUtil, get_connection_operation, setup_user_agent_to_operation_context
+from ._utils import ClientUserAgentUtil, setup_user_agent_to_operation_context
from .entities import Run
from .entities._eager_flow import EagerFlow
from .operations import RunOperations
@@ -34,20 +36,25 @@ class PFClient:

    def __init__(self, **kwargs):
        logger.debug("PFClient init with kwargs: %s", kwargs)
-        self._runs = RunOperations(self)
+        # when this is set, telemetry from this client will use this user agent and ignore the one from OperationContext
+        self._user_agent_override = kwargs.pop(USER_AGENT_OVERRIDE_KEY, None)
        self._connection_provider = kwargs.pop("connection_provider", None)
        self._config = kwargs.get("config", None) or {}
        # The credential is used as an option to override
        # DefaultAzureCredential when using workspace connection provider
        self._credential = kwargs.get("credential", None)

+        # user_agent_override will be applied to all TelemetryMixin operations
+        self._runs = RunOperations(self, user_agent_override=self._user_agent_override)
+        self._flows = FlowOperations(client=self, user_agent_override=self._user_agent_override)
+        self._experiments = ExperimentOperations(self, user_agent_override=self._user_agent_override)
        # Lazy init to avoid azure credential requires too early
        self._connections = None
-        self._flows = FlowOperations(client=self)

        self._tools = ToolOperations()
        # add user agent from kwargs if any
        if isinstance(kwargs.get("user_agent"), str):
            ClientUserAgentUtil.append_user_agent(kwargs["user_agent"])
-        self._experiments = ExperimentOperations(self)
        self._traces = TraceOperations()
        setup_user_agent_to_operation_context(USER_AGENT)

Expand Down Expand Up @@ -243,9 +250,41 @@ def connections(self) -> ConnectionOperations:
"""Connection operations that can manage connections."""
if not self._connections:
self._ensure_connection_provider()
self._connections = get_connection_operation(self._connection_provider, self._credential)
self._connections = PFClient._build_connection_operation(
self._connection_provider,
self._credential,
user_agent_override=self._user_agent_override,
)
return self._connections

+    @staticmethod
+    def _build_connection_operation(connection_provider: str, credential=None, **kwargs):
+        """
+        Build a ConnectionOperation object based on connection provider.
+        :param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc.
+        :type connection_provider: str
+        :param credential: Credential when remote provider, default to chained credential DefaultAzureCredential.
+        :type credential: object
+        """
+        if connection_provider == ConnectionProvider.LOCAL.value:
+            from promptflow._sdk.operations._connection_operations import ConnectionOperations
+
+            logger.debug("PFClient using local connection operations.")
+            connection_operation = ConnectionOperations(**kwargs)
+        elif connection_provider.startswith(ConnectionProvider.AZUREML.value):
+            from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations
+
+            logger.debug(f"PFClient using local azure connection operations with credential {credential}.")
+            connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential, **kwargs)
+        else:
+            raise UserErrorException(
+                target=ErrorTarget.CONTROL_PLANE_SDK,
+                message_format="Unsupported connection provider: {connection_provider}",
+                connection_provider=connection_provider,
+            )
+        return connection_operation

    @property
    def flows(self) -> FlowOperations:
        """Operations on the flow that can manage flows."""
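A usage sketch of the reworked client, under the assumption that the kwargs behave as described in the diff above (both values are illustrative):

```python
from promptflow import PFClient

# Sketch: a host application pins its own telemetry user agent and uses the
# local connection provider.
pf = PFClient(
    user_agent_override="my-host-app/1.0",  # applied to runs/flows/experiments operations
    connection_provider="local",            # dispatches connections() to ConnectionOperations
)
print(pf.connections.list())
```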
35 changes: 26 additions & 9 deletions src/promptflow/promptflow/_sdk/_service/apis/collector.py
@@ -11,6 +11,7 @@
import logging
import traceback
from datetime import datetime
+from typing import Callable

from flask import request
from google.protobuf.json_format import MessageToJson
@@ -28,7 +29,15 @@
from promptflow._utils.thread_utils import ThreadWithContextVars


-def trace_collector(logger: logging.Logger):
+def trace_collector(get_created_by_info_with_cache: Callable, logger: logging.Logger):
+    """
+    This function is intended to be reused in other places, so get_created_by_info_with_cache and logger are
+    passed in to avoid app-related dependencies.
+
+    Args:
+        get_created_by_info_with_cache (Callable): A function that retrieves information about the creator of the trace.
+        logger (logging.Logger): The logger object used for logging.
+    """
    content_type = request.headers.get("Content-Type")
    # binary protobuf encoding
    if "application/x-protobuf" in content_type:
@@ -55,15 +64,17 @@ def trace_collector(logger: logging.Logger):
all_spans.append(span)

        # Create a new thread to write trace to cosmosdb to avoid blocking the main thread
-        ThreadWithContextVars(target=_try_write_trace_to_cosmosdb, args=(all_spans, logger)).start()
+        ThreadWithContextVars(
+            target=_try_write_trace_to_cosmosdb, args=(all_spans, get_created_by_info_with_cache, logger)
+        ).start()
        return "Traces received", 200

    # JSON protobuf encoding
    elif "application/json" in content_type:
        raise NotImplementedError


-def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):
+def _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache: Callable, logger: logging.Logger):
    if not all_spans:
        return
    try:
@@ -78,31 +89,37 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
start_time = datetime.now()
from promptflow._sdk._service.app import CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE
from promptflow.azure._storage.cosmosdb.client import get_client
from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
from promptflow.azure._storage.cosmosdb.summary import Summary

# Load span and summary clients first time may slow.
# So, we load 2 client in parallel for warm up.
span_thread = ThreadWithContextVars(
span_client_thread = ThreadWithContextVars(
target=get_client, args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
)
span_thread.start()
span_client_thread.start()

# Load created_by info first time may slow. So, we load it in parallel for warm up.
created_by_thread = ThreadWithContextVars(target=get_created_by_info_with_cache)
created_by_thread.start()

get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name)

span_thread.join()
span_client_thread.join()
created_by_thread.join()

created_by = get_created_by_info_with_cache()

for span in all_spans:
span_client = get_client(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
result = SpanCosmosDB(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE).persist(span_client)
result = SpanCosmosDB(span, created_by).persist(span_client)
# None means the span already exists, then we don't need to persist the summary also.
if result is not None:
line_summary_client = get_client(
CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name
)
Summary(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE, logger).persist(line_summary_client)
Summary(span, created_by, logger).persist(line_summary_client)
logger.info(
(
f"Finish writing trace to cosmosdb, total spans count: {len(all_spans)}."
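The collector no longer imports CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE from the app; instead the app supplies a cached callable. A minimal sketch of what such a callable might look like, with illustrative fields (the real schema is not shown in this diff):

```python
from functools import lru_cache


@lru_cache(maxsize=1)
def get_created_by_info_with_cache() -> dict:
    """Sketch of a cached created-by resolver: the expensive identity lookup
    runs once (warmed up by the thread above); later calls return the cache.
    The returned fields are illustrative, not the service's actual schema."""
    return {"object_id": "<aad-object-id>", "name": "<display-name>"}


# Wiring sketch (assumed, not from this diff): the Flask app binds its cache and
# logger when registering the route, e.g. with functools.partial:
# app.add_url_rule("/v1/traces", methods=["POST"],
#                  view_func=partial(trace_collector, get_created_by_info_with_cache, logger))
```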