diff --git a/docs/how-to-guides/develop-a-tool/create-dynamic-list-tool-input.md b/docs/how-to-guides/develop-a-tool/create-dynamic-list-tool-input.md
index 475236ff813..7cb8850aa05 100644
--- a/docs/how-to-guides/develop-a-tool/create-dynamic-list-tool-input.md
+++ b/docs/how-to-guides/develop-a-tool/create-dynamic-list-tool-input.md
@@ -125,7 +125,10 @@ pip install my-tools-package>=0.0.8
 ### I'm a tool author, and want to dynamically list Azure resources in my tool input. What should I pay attention to?
 1. Declare the Azure workspace triple "subscription_id", "resource_group_name" and "workspace_name" in the list function signature. The system appends the workspace triple to the function's input parameters when they appear in the signature. See [list_endpoint_names](https://github.com/microsoft/promptflow/blob/main/examples/tools/tool-package-quickstart/my_tool_package/tools/tool_with_dynamic_list_input.py) as an example.
 ```python
-def list_endpoint_names(subscription_id, resource_group_name, workspace_name, prefix: str = "") -> List[Dict[str, str]]:
+def list_endpoint_names(subscription_id: str = None,
+                        resource_group_name: str = None,
+                        workspace_name: str = None,
+                        prefix: str = "") -> List[Dict[str, str]]:
     """This is an example to show how to get Azure ML resource in tool input list function.

     :param subscription_id: Azure subscription id.
@@ -133,6 +136,10 @@ def list_endpoint_names(subscription_id, resource_group_name, workspace_name, pr
     :param workspace_name: Azure ML workspace name.
     :param prefix: prefix to add to each item.
     """
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []
+
     from azure.ai.ml import MLClient
     from azure.identity import DefaultAzureCredential
@@ -185,4 +192,13 @@ If you are unable to see any options in a dynamic list tool input, you may see a
 If this occurs, follow these troubleshooting steps:

 - Note the exact error message shown. This provides details on why the dynamic list failed to populate.
+- Check the tool documentation for any prerequisites or special instructions. For example, if the dynamic list function requires Azure credentials, ensure you have installed the Azure dependencies, logged in, and set the default workspace:
+  ```sh
+  pip install azure-ai-ml
+  ```
+  ```sh
+  az login
+  az account set --subscription <subscription_id>
+  az configure --defaults group=<resource_group_name> workspace=<workspace_name>
+  ```
 - Contact the tool author/support team and report the issue. Provide the error message so they can investigate the root cause.
diff --git a/examples/tools/tool-package-quickstart/my_tool_package/tools/tool_with_dynamic_list_input.py b/examples/tools/tool-package-quickstart/my_tool_package/tools/tool_with_dynamic_list_input.py
index e5950b452d1..87b391f67b8 100644
--- a/examples/tools/tool-package-quickstart/my_tool_package/tools/tool_with_dynamic_list_input.py
+++ b/examples/tools/tool-package-quickstart/my_tool_package/tools/tool_with_dynamic_list_input.py
@@ -31,7 +31,10 @@ def my_list_func(prefix: str = "", size: int = 10, **kwargs) -> List[Dict[str, U
     return result


-def list_endpoint_names(subscription_id, resource_group_name, workspace_name, prefix: str = "") -> List[Dict[str, str]]:
+def list_endpoint_names(subscription_id: str = None,
+                        resource_group_name: str = None,
+                        workspace_name: str = None,
+                        prefix: str = "") -> List[Dict[str, str]]:
     """This is an example to show how to get Azure ML resource in tool input list function.

     :param subscription_id: Azure subscription id.
@@ -39,6 +42,10 @@ def list_endpoint_names(subscription_id, resource_group_name, workspace_name, pr
     :param workspace_name: Azure ML workspace name.
     :param prefix: prefix to add to each item.
     """
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []
+
     from azure.ai.ml import MLClient
     from azure.identity import DefaultAzureCredential
diff --git a/src/promptflow-tools/promptflow/tools/open_model_llm.py b/src/promptflow-tools/promptflow/tools/open_model_llm.py
index 74b1b2088f0..e2c1edf585f 100644
--- a/src/promptflow-tools/promptflow/tools/open_model_llm.py
+++ b/src/promptflow-tools/promptflow/tools/open_model_llm.py
@@ -521,11 +521,15 @@ def parse_endpoint_connection_type(endpoint_connection_name: str) -> Tuple[str,
     return (endpoint_connection_details[0].lower(), endpoint_connection_details[1])


-def list_endpoint_names(subscription_id: str,
-                        resource_group_name: str,
-                        workspace_name: str,
+def list_endpoint_names(subscription_id: str = None,
+                        resource_group_name: str = None,
+                        workspace_name: str = None,
                         return_endpoint_url: bool = False,
                         force_refresh: bool = False) -> List[Dict[str, Union[str, int, float, list, Dict]]]:
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []
+
     cache_file_path = None
     try:
         with tempfile.NamedTemporaryFile(delete=False) as temp_file:
@@ -598,10 +602,14 @@ def list_endpoint_names(subscription_id: str,
     return list_of_endpoints

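The same triad guard is applied to every Azure-backed list function in this diff. As a self-contained illustration of the pattern, here is a minimal, runnable sketch; the `list_my_resources` name and its fake endpoint data are hypothetical stand-ins for a real Azure call:

```python
from typing import Dict, List


def list_my_resources(subscription_id: str = None,
                      resource_group_name: str = None,
                      workspace_name: str = None,
                      prefix: str = "") -> List[Dict[str, str]]:
    """Hypothetical dynamic-list function following the guard pattern in this PR."""
    # Return an empty list when the workspace triad is missing, so the UI shows
    # "no options" instead of surfacing an exception from a doomed Azure call.
    if not subscription_id or not resource_group_name or not workspace_name:
        return []
    endpoints = [f"{prefix}endpoint-a", f"{prefix}endpoint-b"]  # stand-in for an Azure query
    return [{"value": name, "display_value": name} for name in endpoints]


if __name__ == "__main__":
    print(list_my_resources())  # [] -- triad missing
    print(list_my_resources("sub", "rg", "ws", prefix="my-"))
```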
-def list_deployment_names(subscription_id: str,
-                          resource_group_name: str,
-                          workspace_name: str,
+def list_deployment_names(subscription_id: str = None,
+                          resource_group_name: str = None,
+                          workspace_name: str = None,
                           endpoint: str = None) -> List[Dict[str, Union[str, int, float, list, Dict]]]:
+    # return an empty list if workspace triad is not available.
+    if not subscription_id or not resource_group_name or not workspace_name:
+        return []
+
     deployment_default_list = [{
         "value": DEPLOYMENT_DEFAULT,
         "display_value": DEPLOYMENT_DEFAULT,
diff --git a/src/promptflow/promptflow/_cli/_pf/_experiment.py b/src/promptflow/promptflow/_cli/_pf/_experiment.py
index 38a72c2a534..22bc107b20e 100644
--- a/src/promptflow/promptflow/_cli/_pf/_experiment.py
+++ b/src/promptflow/promptflow/_cli/_pf/_experiment.py
@@ -131,13 +131,13 @@ def add_experiment_start(subparsers):
     # Start a named experiment:
     pf experiment start -n my_experiment --inputs data1=data1_val data2=data2_val
     # Run an experiment by yaml file:
-    pf experiment start --file path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
+    pf experiment start --template path/to/my_experiment.exp.yaml --inputs data1=data1_val data2=data2_val
     """
     activate_action(
         name="start",
         description="Start an experiment.",
         epilog=epilog,
-        add_params=[add_param_name, add_param_file, add_param_input, add_param_stream] + base_params,
+        add_params=[add_param_name, add_param_template, add_param_input, add_param_stream] + base_params,
         subparsers=subparsers,
         help_message="Start an experiment.",
         action_param_name="sub_action",
@@ -235,20 +235,18 @@ def start_experiment(args: argparse.Namespace):
     if args.name:
         logger.debug(f"Starting a named experiment {args.name}.")
         inputs = list_of_dict_to_dict(args.inputs)
-        if inputs:
-            logger.warning("The inputs of named experiment cannot be modified.")
         client = _get_pf_client()
         experiment = client._experiments.get(args.name)
-        result = client._experiments.start(experiment=experiment, stream=args.stream)
-    elif args.file:
+        result = client._experiments.start(experiment=experiment, inputs=inputs, stream=args.stream)
+    elif args.template:
         from promptflow._sdk._load_functions import _load_experiment

-        logger.debug(f"Starting an anonymous experiment {args.file}.")
-        experiment = _load_experiment(source=args.file)
+        logger.debug(f"Starting an anonymous experiment {args.template}.")
+        experiment = _load_experiment(source=args.template)
         inputs = list_of_dict_to_dict(args.inputs)
         result = _get_pf_client()._experiments.start(experiment=experiment, inputs=inputs, stream=args.stream)
     else:
-        raise UserErrorException("To start an experiment, one of [name, file] must be specified.")
+        raise UserErrorException("To start an experiment, one of [name, template] must be specified.")
     print(json.dumps(result._to_dict(), indent=4))
diff --git a/src/promptflow/promptflow/_constants.py b/src/promptflow/promptflow/_constants.py
index b38c6164e5f..e129d155e6b 100644
--- a/src/promptflow/promptflow/_constants.py
+++ b/src/promptflow/promptflow/_constants.py
@@ -10,6 +10,7 @@
 PROMPTFLOW_CONNECTIONS = "PROMPTFLOW_CONNECTIONS"
 PROMPTFLOW_SECRETS_FILE = "PROMPTFLOW_SECRETS_FILE"
 PF_NO_INTERACTIVE_LOGIN = "PF_NO_INTERACTIVE_LOGIN"
+PF_RUN_AS_BUILT_BINARY = "PF_RUN_AS_BUILT_BINARY"
 PF_LOGGING_LEVEL = "PF_LOGGING_LEVEL"
 OPENAI_API_KEY = "openai-api-key"
 BING_API_KEY = "bing-api-key"
@@ -19,6 +20,7 @@
 ERROR_RESPONSE_COMPONENT_NAME = "promptflow"
 EXTENSION_UA = "prompt-flow-extension"
 LANGUAGE_KEY = "language"
+USER_AGENT_OVERRIDE_KEY = "user_agent_override"

 # Tool meta info
 ICON_DARK = "icon_dark"
@@ -166,6 +168,13 @@ class MessageFormatType:

 DEFAULT_OUTPUT_NAME = "output"

+OUTPUT_FILE_NAME = "output.jsonl"
+
+
+class OutputsFolderName:
+    FLOW_OUTPUTS = "flow_outputs"
+    FLOW_ARTIFACTS = "flow_artifacts"
+    NODE_ARTIFACTS = "node_artifacts"

 class ConnectionType(str, Enum):
diff --git a/src/promptflow/promptflow/_core/operation_context.py b/src/promptflow/promptflow/_core/operation_context.py
index 9fc6804e99d..ddb5ef067c7 100644
--- a/src/promptflow/promptflow/_core/operation_context.py
+++ b/src/promptflow/promptflow/_core/operation_context.py
@@ -153,6 +153,9 @@ def append_user_agent(self, user_agent: str):
             user_agent (str): The user agent information to append.
         """
         if OperationContext.USER_AGENT_KEY in self:
+            # TODO: this judgement can be wrong when a user agent is a substring of another,
+            # e.g. "Mozilla/5.0" and "Mozilla/5.0 (Windows NT 10.0; Win64; x64)";
+            # however, changing this code may impact existing logic, so we won't change it for now
             if user_agent not in self.user_agent:
                 self.user_agent = f"{self.user_agent.strip()} {user_agent.strip()}"
         else:
diff --git a/src/promptflow/promptflow/_sdk/_load_functions.py b/src/promptflow/promptflow/_sdk/_load_functions.py
index 6df42545ef4..f4ebdc1a7ce 100644
--- a/src/promptflow/promptflow/_sdk/_load_functions.py
+++ b/src/promptflow/promptflow/_sdk/_load_functions.py
@@ -1,7 +1,6 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-import hashlib
 from os import PathLike
 from pathlib import Path
 from typing import IO, AnyStr, Optional, Union
@@ -11,7 +10,6 @@
 from .._utils.logger_utils import get_cli_sdk_logger
 from .._utils.yaml_utils import load_yaml
 from ._errors import MultipleExperimentTemplateError, NoExperimentTemplateError
-from ._utils import _sanitize_python_variable_name
 from .entities import Run
 from .entities._connection import CustomConnection, _Connection
 from .entities._experiment import Experiment, ExperimentTemplate
@@ -205,8 +203,5 @@ def _load_experiment(
     absolute_path = source.resolve().absolute().as_posix()
     if not source.exists():
         raise NoExperimentTemplateError(f"Experiment file {absolute_path} not found.")
-    anonymous_exp_name = _sanitize_python_variable_name(
-        f"{source.stem}_{hashlib.sha1(absolute_path.encode('utf-8')).hexdigest()}"
-    )
-    experiment = load_common(Experiment, source, params_override=[{"name": anonymous_exp_name}], **kwargs)
+    experiment = load_common(Experiment, source, **kwargs)
     return experiment
diff --git a/src/promptflow/promptflow/_sdk/_pf_client.py b/src/promptflow/promptflow/_sdk/_pf_client.py
index c6f99a0ef6e..0f1a8e3b7ad 100644
--- a/src/promptflow/promptflow/_sdk/_pf_client.py
+++ b/src/promptflow/promptflow/_sdk/_pf_client.py
@@ -6,12 +6,14 @@
 from pathlib import Path
 from typing import Any, Dict, List, Union

+from .._constants import USER_AGENT_OVERRIDE_KEY
 from .._utils.logger_utils import get_cli_sdk_logger
+from ..exceptions import ErrorTarget, UserErrorException
 from ._configuration import Configuration
-from ._constants import MAX_SHOW_DETAILS_RESULTS
+from ._constants import MAX_SHOW_DETAILS_RESULTS, ConnectionProvider
 from ._load_functions import load_flow
 from ._user_agent import USER_AGENT
-from ._utils import ClientUserAgentUtil, get_connection_operation, setup_user_agent_to_operation_context
+from ._utils import ClientUserAgentUtil, setup_user_agent_to_operation_context
 from .entities import Run
 from .entities._eager_flow import EagerFlow
 from .operations import RunOperations
@@ -34,20 +36,25 @@ class PFClient:

     def __init__(self, **kwargs):
         logger.debug("PFClient init with kwargs: %s", kwargs)
-        self._runs = RunOperations(self)
+        # when this is set, telemetry from this client will use this user agent and ignore the one from OperationContext
+        self._user_agent_override = kwargs.pop(USER_AGENT_OVERRIDE_KEY, None)
         self._connection_provider = kwargs.pop("connection_provider", None)
         self._config = kwargs.get("config", None) or {}
         # The credential is used as an option to override
         # DefaultAzureCredential when using workspace connection provider
         self._credential = kwargs.get("credential", None)
+
+        # user_agent_override will be applied to all TelemetryMixin operations
+        self._runs = RunOperations(self, user_agent_override=self._user_agent_override)
+        self._flows = FlowOperations(client=self, user_agent_override=self._user_agent_override)
+        self._experiments = ExperimentOperations(self, user_agent_override=self._user_agent_override)
         # Lazy init to avoid azure credential requires too early
         self._connections = None
-        self._flows = FlowOperations(client=self)
         self._tools = ToolOperations()
         # add user agent from kwargs if any
         if isinstance(kwargs.get("user_agent"), str):
             ClientUserAgentUtil.append_user_agent(kwargs["user_agent"])
-        self._experiments = ExperimentOperations(self)
         self._traces = TraceOperations()
         setup_user_agent_to_operation_context(USER_AGENT)
@@ -243,9 +250,41 @@ def connections(self) -> ConnectionOperations:
         """Connection operations that can manage connections."""
         if not self._connections:
             self._ensure_connection_provider()
-            self._connections = get_connection_operation(self._connection_provider, self._credential)
+            self._connections = PFClient._build_connection_operation(
+                self._connection_provider,
+                self._credential,
+                user_agent_override=self._user_agent_override,
+            )
         return self._connections

+    @staticmethod
+    def _build_connection_operation(connection_provider: str, credential=None, **kwargs):
+        """
+        Build a ConnectionOperation object based on connection provider.
+
+        :param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc.
+        :type connection_provider: str
+        :param credential: Credential when remote provider, default to chained credential DefaultAzureCredential.
+        :type credential: object
+        """
+        if connection_provider == ConnectionProvider.LOCAL.value:
+            from promptflow._sdk.operations._connection_operations import ConnectionOperations
+
+            logger.debug("PFClient using local connection operations.")
+            connection_operation = ConnectionOperations(**kwargs)
+        elif connection_provider.startswith(ConnectionProvider.AZUREML.value):
+            from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations
+
+            logger.debug(f"PFClient using local azure connection operations with credential {credential}.")
+            connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential, **kwargs)
+        else:
+            raise UserErrorException(
+                target=ErrorTarget.CONTROL_PLANE_SDK,
+                message_format="Unsupported connection provider: {connection_provider}",
+                connection_provider=connection_provider,
+            )
+        return connection_operation
+
     @property
     def flows(self) -> FlowOperations:
         """Operations on the flow that can manage flows."""
diff --git a/src/promptflow/promptflow/_sdk/_service/apis/collector.py b/src/promptflow/promptflow/_sdk/_service/apis/collector.py
index 0a31761074b..647f8a64a11 100644
--- a/src/promptflow/promptflow/_sdk/_service/apis/collector.py
+++ b/src/promptflow/promptflow/_sdk/_service/apis/collector.py
@@ -11,6 +11,7 @@
 import logging
 import traceback
 from datetime import datetime
+from typing import Callable

 from flask import request
 from google.protobuf.json_format import MessageToJson
@@ -28,7 +29,15 @@
 from promptflow._utils.thread_utils import ThreadWithContextVars


-def trace_collector(logger: logging.Logger):
+def trace_collector(get_created_by_info_with_cache: Callable, logger: logging.Logger):
+    """
+    This function is intended to be reused in other places, so get_created_by_info_with_cache and logger are
+    passed in to avoid app-related dependencies.
+
+    Args:
+        get_created_by_info_with_cache (Callable): A function that retrieves information about the creator of the trace.
+        logger (logging.Logger): The logger object used for logging.
+    """
     content_type = request.headers.get("Content-Type")
     # binary protobuf encoding
     if "application/x-protobuf" in content_type:
@@ -55,7 +64,9 @@ def trace_collector(logger: logging.Logger):
         all_spans.append(span)

         # Create a new thread to write trace to cosmosdb to avoid blocking the main thread
-        ThreadWithContextVars(target=_try_write_trace_to_cosmosdb, args=(all_spans, logger)).start()
+        ThreadWithContextVars(
+            target=_try_write_trace_to_cosmosdb, args=(all_spans, get_created_by_info_with_cache, logger)
+        ).start()
         return "Traces received", 200

     # JSON protobuf encoding
@@ -63,7 +74,7 @@ def trace_collector(logger: logging.Logger):
     raise NotImplementedError


-def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):
+def _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache: Callable, logger: logging.Logger):
     if not all_spans:
         return
     try:
@@ -78,31 +89,37 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):
         logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
         start_time = datetime.now()

-        from promptflow._sdk._service.app import CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE
         from promptflow.azure._storage.cosmosdb.client import get_client
         from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
         from promptflow.azure._storage.cosmosdb.summary import Summary

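The hunk below parallelizes the first-time warm-up of the Cosmos DB clients and the created_by lookup. A minimal sketch of that pattern using plain `threading.Thread` (promptflow's `ThreadWithContextVars` is a wrapper that additionally propagates context variables; the sleeps are stand-ins for slow client construction):

```python
import threading
import time


def get_client(container: str) -> str:
    time.sleep(1)  # stand-in for a slow first-time client build
    return f"client/{container}"


def get_created_by_info_with_cache() -> dict:
    time.sleep(1)  # stand-in for the slow ARM token fetch
    return {"object_id": "oid", "tenant_id": "tid"}


# Warm the span client and the created_by info on worker threads while the
# line-summary client is built on the main thread; the wall-clock wait is the
# slowest single task instead of the sum of all three.
span_client_thread = threading.Thread(target=get_client, args=("Span",))
created_by_thread = threading.Thread(target=get_created_by_info_with_cache)
span_client_thread.start()
created_by_thread.start()
get_client("LineSummary")
span_client_thread.join()
created_by_thread.join()
```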
        # Loading span and summary clients the first time may be slow,
        # so we load the 2 clients in parallel to warm up.
-        span_thread = ThreadWithContextVars(
+        span_client_thread = ThreadWithContextVars(
             target=get_client, args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
         )
-        span_thread.start()
+        span_client_thread.start()
+
+        # Loading created_by info the first time may be slow, so we load it in parallel to warm up.
+        created_by_thread = ThreadWithContextVars(target=get_created_by_info_with_cache)
+        created_by_thread.start()

         get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name)

-        span_thread.join()
+        span_client_thread.join()
+        created_by_thread.join()
+
+        created_by = get_created_by_info_with_cache()

         for span in all_spans:
             span_client = get_client(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
-            result = SpanCosmosDB(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE).persist(span_client)
+            result = SpanCosmosDB(span, created_by).persist(span_client)
             # None means the span already exists, then we don't need to persist the summary also.
             if result is not None:
                 line_summary_client = get_client(
                     CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name
                 )
-                Summary(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE, logger).persist(line_summary_client)
+                Summary(span, created_by, logger).persist(line_summary_client)
         logger.info(
             (
                 f"Finish writing trace to cosmosdb, total spans count: {len(all_spans)}."
diff --git a/src/promptflow/promptflow/_sdk/_service/apis/connection.py b/src/promptflow/promptflow/_sdk/_service/apis/connection.py
index 831930d9b67..1d670237523 100644
--- a/src/promptflow/promptflow/_sdk/_service/apis/connection.py
+++ b/src/promptflow/promptflow/_sdk/_service/apis/connection.py
@@ -10,7 +10,7 @@
 import promptflow._sdk.schemas._connection as connection
 from promptflow._sdk._configuration import Configuration
 from promptflow._sdk._service import Namespace, Resource, fields
-from promptflow._sdk._service.utils.utils import build_pfs_user_agent, local_user_only, make_response_no_content
+from promptflow._sdk._service.utils.utils import get_client_from_request, local_user_only, make_response_no_content
 from promptflow._sdk.entities._connection import _Connection

 api = Namespace("Connections", description="Connections Management")
@@ -66,14 +66,10 @@ def validate_working_directory(value):


 def _get_connection_operation(working_directory=None):
-    from promptflow._sdk._pf_client import PFClient
-
     connection_provider = Configuration().get_connection_provider(path=working_directory)
     # get_connection_operation is a shared function, so we build user agent based on request first and
     # then pass it to the function
-    connection_operation = PFClient(
-        connection_provider=connection_provider, user_agent=build_pfs_user_agent()
-    ).connections
+    connection_operation = get_client_from_request(connection_provider=connection_provider).connections
     return connection_operation
diff --git a/src/promptflow/promptflow/_sdk/_service/app.py b/src/promptflow/promptflow/_sdk/_service/app.py
index 30993d0da05..385543ec149 100644
--- a/src/promptflow/promptflow/_sdk/_service/app.py
+++ b/src/promptflow/promptflow/_sdk/_service/app.py
@@ -2,7 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 import logging
-import sys
+import threading
 import time
 from datetime import datetime, timedelta
 from logging.handlers import RotatingFileHandler
@@ -29,6 +29,7 @@
     FormattedException,
     get_current_env_pfs_file,
     get_port_from_config,
+    is_run_from_built_binary,
     kill_exist_service,
 )
 from promptflow._sdk._utils import get_promptflow_sdk_version, overwrite_null_std_logger, read_write_by_user
@@ -42,9 +43,6 @@ def heartbeat():
     return jsonify(response)


-CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE = {}
-
-
 def create_app():
     app = Flask(__name__)

@@ -54,7 +52,9 @@ def create_app():
     CORS(app)

     app.add_url_rule("/heartbeat", view_func=heartbeat)
-    app.add_url_rule("/v1/traces", view_func=lambda: trace_collector(app.logger), methods=["POST"])
+    app.add_url_rule(
+        "/v1/traces", view_func=lambda: trace_collector(get_created_by_info_with_cache, app.logger), methods=["POST"]
+    )
     with app.app_context():
         api_v1 = Blueprint("Prompt Flow Service", __name__, url_prefix="/v1.0")

@@ -74,7 +74,7 @@ def create_app():
     # Enable log
     app.logger.setLevel(logging.INFO)
     # each env will have its own log file
-    if sys.executable.endswith("pfcli.exe"):
+    if is_run_from_built_binary():
         log_file = HOME_PROMPT_FLOW_DIR / PF_SERVICE_LOG_FILE
         log_file.touch(mode=read_write_by_user(), exist_ok=True)
     else:
@@ -86,33 +86,6 @@ def create_app():
     # Set app logger to the only one RotatingFileHandler to avoid duplicate logs
     app.logger.handlers = [handler]

-    def initialize_created_by_info():
-        from promptflow._sdk._configuration import Configuration
-        from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider
-
-        trace_provider = Configuration.get_instance().get_trace_provider()
-        if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None:
-            return
-        try:
-            import jwt
-            from azure.identity import DefaultAzureCredential
-
-            from promptflow.azure._utils.general import get_arm_token
-
-            default_credential = DefaultAzureCredential()
-
-            token = get_arm_token(credential=default_credential)
-            decoded_token = jwt.decode(token, options={"verify_signature": False})
-            user_object_id, user_tenant_id = decoded_token["oid"], decoded_token["tid"]
-            CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE.update(
-                {
-                    "object_id": user_object_id,
-                    "tenant_id": user_tenant_id,
-                }
-            )
-        except Exception as e:
-            current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")
-
     # Basic error handler
     @api.errorhandler(Exception)
     def handle_exception(e):
@@ -167,8 +140,42 @@ def monitor_request():
                     kill_exist_service(port)
                     break

-    initialize_created_by_info()
-    if not sys.executable.endswith("pfcli.exe"):
+    if not is_run_from_built_binary():
         monitor_thread = ThreadWithContextVars(target=monitor_request, daemon=True)
         monitor_thread.start()
     return app, api
+
+
+created_by_for_local_to_cloud_trace = {}
+created_by_for_local_to_cloud_trace_lock = threading.Lock()
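The `get_created_by_info_with_cache` function defined below is a double-checked locking cache: the unlocked read keeps the hot path lock-free, and the second check inside the lock stops concurrent first callers from fetching twice. Distilled into a runnable sketch (the `slow_lookup` helper is a hypothetical stand-in for the token fetch):

```python
import threading
import time

_cache: dict = {}
_cache_lock = threading.Lock()


def slow_lookup() -> dict:
    time.sleep(3)  # stand-in for the ~3s credential/token fetch
    return {"object_id": "oid", "tenant_id": "tid", "name": "user"}


def get_info_with_cache() -> dict:
    if _cache:  # fast path: no lock once populated
        return _cache
    with _cache_lock:
        if _cache:  # re-check: another thread may have filled it meanwhile
            return _cache
        _cache.update(slow_lookup())
    return _cache
```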
+
+
+def get_created_by_info_with_cache():
+    if len(created_by_for_local_to_cloud_trace) > 0:
+        return created_by_for_local_to_cloud_trace
+    with created_by_for_local_to_cloud_trace_lock:
+        if len(created_by_for_local_to_cloud_trace) > 0:
+            return created_by_for_local_to_cloud_trace
+        try:
+            # The total time of collecting info is about 3s.
+            import jwt
+            from azure.identity import DefaultAzureCredential
+
+            from promptflow.azure._utils.general import get_arm_token
+
+            default_credential = DefaultAzureCredential()
+
+            token = get_arm_token(credential=default_credential)
+            decoded_token = jwt.decode(token, options={"verify_signature": False})
+            created_by_for_local_to_cloud_trace.update(
+                {
+                    "object_id": decoded_token["oid"],
+                    "tenant_id": decoded_token["tid"],
+                    # Use appid as fallback for service principal scenario.
+                    "name": decoded_token.get("name", decoded_token.get("appid", "")),
+                }
+            )
+        except Exception as e:
+            # This function is only intended to be used in the Flask app.
+            current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")
+    return created_by_for_local_to_cloud_trace
diff --git a/src/promptflow/promptflow/_sdk/_service/entry.py b/src/promptflow/promptflow/_sdk/_service/entry.py
index 5514224e977..564480a634d 100644
--- a/src/promptflow/promptflow/_sdk/_service/entry.py
+++ b/src/promptflow/promptflow/_sdk/_service/entry.py
@@ -21,6 +21,7 @@
     get_port_from_config,
     get_started_service_info,
     is_port_in_use,
+    is_run_from_built_binary,
     kill_exist_service,
 )
 from promptflow._sdk._utils import get_promptflow_sdk_version, print_pf_version
@@ -104,7 +105,7 @@ def validate_port(port, force_start):
         port = get_port_from_config(create_if_not_exists=True)
         validate_port(port, args.force)

-    if sys.executable.endswith("pfcli.exe"):
+    if is_run_from_built_binary():
         # For msi installer, use sdk api to start pfs since it's not supported to invoke waitress by cli directly
         # after packaged by Pyinstaller.
         app, _ = create_app()
diff --git a/src/promptflow/promptflow/_sdk/_service/utils/utils.py b/src/promptflow/promptflow/_sdk/_service/utils/utils.py
index 1cdafa39912..5e8d0a25e24 100644
--- a/src/promptflow/promptflow/_sdk/_service/utils/utils.py
+++ b/src/promptflow/promptflow/_sdk/_service/utils/utils.py
@@ -17,6 +17,7 @@
 import requests
 from flask import abort, make_response, request

+from promptflow._constants import PF_RUN_AS_BUILT_BINARY
 from promptflow._sdk._constants import (
     DEFAULT_ENCODING,
     HOME_PROMPT_FLOW_DIR,
@@ -60,7 +61,7 @@ def get_current_env_pfs_file(file_name):


 def get_port_from_config(create_if_not_exists=False):
-    if sys.executable.endswith("pfcli.exe"):
+    if is_run_from_built_binary():
         port_file_path = HOME_PROMPT_FLOW_DIR / PF_SERVICE_PORT_FILE
         port_file_path.touch(mode=read_write_by_user(), exist_ok=True)
     else:
@@ -79,7 +80,7 @@ def get_port_from_config(create_if_not_exists=False):


 def dump_port_to_config(port):
-    if sys.executable.endswith("pfcli.exe"):
+    if is_run_from_built_binary():
         port_file_path = HOME_PROMPT_FLOW_DIR / PF_SERVICE_PORT_FILE
         port_file_path.touch(mode=read_write_by_user(), exist_ok=True)
     else:
@@ -231,13 +232,34 @@ def __post_init__(self, exception, status_code):


 def build_pfs_user_agent():
-    extra_agent = f"local_pfs/{VERSION}"
-    if request.user_agent.string:
-        return f"{request.user_agent.string} {extra_agent}"
-    return extra_agent
+    user_agent = request.user_agent.string
+    user_agent_for_local_pfs = f"local_pfs/{VERSION}"
+    if user_agent:
+        return f"{user_agent} {user_agent_for_local_pfs}"
+    return user_agent_for_local_pfs


-def get_client_from_request() -> "PFClient":
+def get_client_from_request(*, connection_provider=None) -> "PFClient":
+    """
+    Build a PFClient instance based on the current request in local PFS.
+
+    The user agent may be different for each request.
+    """
     from promptflow._sdk._pf_client import PFClient

-    return PFClient(user_agent=build_pfs_user_agent())
+    user_agent = build_pfs_user_agent()
+
+    if connection_provider:
+        pf_client = PFClient(connection_provider=connection_provider, user_agent_override=user_agent)
+    else:
+        pf_client = PFClient(user_agent_override=user_agent)
+    return pf_client
+
+
+def is_run_from_built_binary():
+    """
+    Use this function to trigger different behavior between calls from the promptflow sdk/cli and the built binary.
+
+    An environment variable is also supported so that customers can control the triggering.
+    """
+    return sys.executable.endswith("pfcli.exe") or os.environ.get(PF_RUN_AS_BUILT_BINARY, "").lower() == "true"
diff --git a/src/promptflow/promptflow/_sdk/_telemetry/activity.py b/src/promptflow/promptflow/_sdk/_telemetry/activity.py
index 0ffda098cbf..4693780c67e 100644
--- a/src/promptflow/promptflow/_sdk/_telemetry/activity.py
+++ b/src/promptflow/promptflow/_sdk/_telemetry/activity.py
@@ -105,6 +105,7 @@ def log_activity(
     activity_name,
     activity_type=ActivityType.INTERNALCALL,
     custom_dimensions=None,
+    user_agent=None,
 ):
     """Log an activity.

@@ -121,12 +122,16 @@ def log_activity(
     :type activity_type: str
     :param custom_dimensions: The custom properties of the activity.
     :type custom_dimensions: dict
+    :param user_agent: Specify the user agent. If not specified, the user agent will be taken from OperationContext.
+    :type user_agent: str
     :return: None
     """
     if not custom_dimensions:
         custom_dimensions = {}

-    user_agent = ClientUserAgentUtil.get_user_agent()
+    # provided user agent will be respected even if it's ""
+    if user_agent is None:
+        user_agent = ClientUserAgentUtil.get_user_agent()
     request_id = request_id_context.get()
     if not request_id:
         # public function call
@@ -179,17 +184,6 @@ def log_activity(
         raise exception


-def extract_telemetry_info(self):
-    """Extract pf telemetry info from given telemetry mix-in instance."""
-    result = {}
-    try:
-        if isinstance(self, TelemetryMixin):
-            return self._get_telemetry_values()
-    except Exception:
-        pass
-    return result
-
-
 def update_activity_name(activity_name, kwargs=None, args=None):
     """Update activity name according to kwargs. For flow test, we want to know if it's node test."""
     if activity_name == "pf.flows.test":
@@ -233,10 +227,21 @@ def wrapper(self, *args, **kwargs):

             logger = get_telemetry_logger()

-            custom_dimensions.update(extract_telemetry_info(self))
+            if isinstance(self, TelemetryMixin):
+                custom_dimensions.update(self._get_telemetry_values())
+                user_agent = self._get_user_agent_override()
+            else:
+                user_agent = None
+
             # update activity name according to kwargs.
             _activity_name = update_activity_name(activity_name, kwargs=kwargs)
-            with log_activity(logger, _activity_name, activity_type, custom_dimensions):
+            with log_activity(
+                logger=logger,
+                activity_name=_activity_name,
+                activity_type=activity_type,
+                custom_dimensions=custom_dimensions,
+                user_agent=user_agent,
+            ):
                 if _activity_name in HINT_ACTIVITY_NAME:
                     hint_for_update()
                     # set check_latest_version as daemon thread to avoid blocking main thread
diff --git a/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py b/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py
index 5abd5a483b0..ce05ba30039 100644
--- a/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py
+++ b/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py
@@ -2,7 +2,9 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 import logging
+from typing import Optional

+from promptflow._constants import USER_AGENT_OVERRIDE_KEY
 from promptflow._sdk._configuration import Configuration

 PROMPTFLOW_LOGGER_NAMESPACE = "promptflow._sdk._telemetry"
@@ -10,17 +12,28 @@
 class TelemetryMixin(object):
     def __init__(self, **kwargs):
+        self._user_agent_override = kwargs.pop(USER_AGENT_OVERRIDE_KEY, None)
+        # Need to call init for potential parent, otherwise it won't be initialized.
+        # TODO: however, object.__init__() takes exactly one argument (the instance to initialize), so this will fail
+        # if there are any kwargs left.
         super().__init__(**kwargs)

     def _get_telemetry_values(self, *args, **kwargs):  # pylint: disable=unused-argument
-        """Return the telemetry values of object.
+        """Return the telemetry values of object, will be set as custom_dimensions in telemetry.

         :return: The telemetry values
         :rtype: Dict
         """
         return {}

+    def _get_user_agent_override(self) -> Optional[str]:
+        """If we have a user agent bound to this object via the constructor, return it.
+
+        Telemetry from this object will use this user agent and ignore the one from OperationContext.
+        """
+        return self._user_agent_override
+

 class WorkspaceTelemetryMixin(TelemetryMixin):
     def __init__(self, subscription_id, resource_group_name, workspace_name, **kwargs):
@@ -31,7 +44,7 @@ def __init__(self, subscription_id, resource_group_name, workspace_name, **kwarg
         super().__init__(**kwargs)

     def _get_telemetry_values(self, *args, **kwargs):  # pylint: disable=unused-argument
-        """Return the telemetry values of run operations.
+        """Return the telemetry values of object, will be set as custom_dimensions in telemetry.

         :return: The telemetry values
         :rtype: Dict
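Reduced to its essentials, the override plumbing is a kwargs-popping mixin: `PFClient` passes `user_agent_override` down, the mixin intercepts it, and the `monitor_operation` decorator reads it back. A sketch assuming only what this diff shows (this `RunOperations` is a bare stand-in, with none of the real class's dependencies):

```python
from typing import Optional

USER_AGENT_OVERRIDE_KEY = "user_agent_override"  # mirrors the constant added in this PR


class TelemetryMixin:
    def __init__(self, **kwargs):
        # Pop the override before delegating so object.__init__ receives no kwargs.
        self._user_agent_override = kwargs.pop(USER_AGENT_OVERRIDE_KEY, None)
        super().__init__(**kwargs)

    def _get_user_agent_override(self) -> Optional[str]:
        return self._user_agent_override


class RunOperations(TelemetryMixin):
    def __init__(self, client=None, **kwargs):
        super().__init__(**kwargs)
        self._client = client


ops = RunOperations(client=object(), user_agent_override="per-request-ua/1.0")
assert ops._get_user_agent_override() == "per-request-ua/1.0"
assert RunOperations()._get_user_agent_override() is None
```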
diff --git a/src/promptflow/promptflow/_sdk/_utils.py b/src/promptflow/promptflow/_sdk/_utils.py
index 92b0a81dc8c..0d16f5ec717 100644
--- a/src/promptflow/promptflow/_sdk/_utils.py
+++ b/src/promptflow/promptflow/_sdk/_utils.py
@@ -56,7 +56,6 @@
     VARIANTS,
     AzureMLWorkspaceTriad,
     CommonYamlFields,
-    ConnectionProvider,
 )
 from promptflow._sdk._errors import (
     DecryptConnectionError,
@@ -1022,41 +1021,6 @@ def parse_remote_flow_pattern(flow: object) -> str:
     return flow_name


-def get_connection_operation(connection_provider: str, credential=None, user_agent: str = None):
-    """
-    Get connection operation based on connection provider.
-    This function will be called by PFClient, so please do not refer to PFClient in this function.
-
-    :param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc.
-    :type connection_provider: str
-    :param credential: Credential when remote provider, default to chained credential DefaultAzureCredential.
-    :type credential: object
-    :param user_agent: User Agent
-    :type user_agent: str
-    """
-    if connection_provider == ConnectionProvider.LOCAL.value:
-        from promptflow._sdk.operations._connection_operations import ConnectionOperations
-
-        logger.debug("PFClient using local connection operations.")
-        connection_operation = ConnectionOperations()
-    elif connection_provider.startswith(ConnectionProvider.AZUREML.value):
-        from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations
-
-        logger.debug(f"PFClient using local azure connection operations with credential {credential}.")
-        if user_agent is None:
-            connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential)
-        else:
-            connection_operation = LocalAzureConnectionOperations(connection_provider, user_agent=user_agent)
-    else:
-        error = ValueError(f"Unsupported connection provider: {connection_provider}")
-        raise UserErrorException(
-            target=ErrorTarget.CONTROL_PLANE_SDK,
-            message=str(error),
-            error=error,
-        )
-    return connection_operation
-
-
 # extract open read/write as partial to centralize the encoding
 read_open = partial(open, mode="r", encoding=DEFAULT_ENCODING)
 write_open = partial(open, mode="w", encoding=DEFAULT_ENCODING)
diff --git a/src/promptflow/promptflow/_sdk/operations/_flow_operations.py b/src/promptflow/promptflow/_sdk/operations/_flow_operations.py
index 64a850192de..d816cf945af 100644
--- a/src/promptflow/promptflow/_sdk/operations/_flow_operations.py
+++ b/src/promptflow/promptflow/_sdk/operations/_flow_operations.py
@@ -48,9 +48,9 @@ class FlowOperations(TelemetryMixin):
     """FlowOperations."""

-    def __init__(self, client):
+    def __init__(self, client, **kwargs):
+        super().__init__(**kwargs)
         self._client = client
-        super().__init__()

     @monitor_operation(activity_name="pf.flows.test", activity_type=ActivityType.PUBLICAPI)
     def test(
diff --git a/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py b/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py
index f5ea3e0204f..0fef27a51b3 100644
--- a/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py
+++ b/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py
@@ -14,6 +14,7 @@
 from filelock import FileLock

+from promptflow._constants import OUTPUT_FILE_NAME, OutputsFolderName
 from promptflow._sdk._constants import (
     HOME_PROMPT_FLOW_DIR,
     LINE_NUMBER,
@@ -45,6 +46,7 @@
     load_multimedia_data_recursively,
     resolve_multimedia_data_recursively,
 )
+from promptflow._utils.utils import prepare_folder
 from promptflow._utils.yaml_utils import load_yaml
 from promptflow.batch._result import BatchResult
 from promptflow.contracts.multimedia import Image
@@ -191,13 +193,13 @@ class LocalStorageOperations(AbstractBatchRunStorage):

     def __init__(self, run: Run, stream=False, run_mode=RunMode.Test):
         self._run = run
-        self.path = self._prepare_folder(self._run._output_path)
+        self.path = prepare_folder(self._run._output_path)

         self.logger = LoggerOperations(
             file_path=self.path / LocalStorageFilenames.LOG, stream=stream, run_mode=run_mode
         )
         # snapshot
-        self._snapshot_folder_path = self._prepare_folder(self.path / LocalStorageFilenames.SNAPSHOT_FOLDER)
+        self._snapshot_folder_path = prepare_folder(self.path / LocalStorageFilenames.SNAPSHOT_FOLDER)
         self._dag_path = self._snapshot_folder_path / LocalStorageFilenames.DAG
         self._flow_tools_json_path = (
             self._snapshot_folder_path / PROMPT_FLOW_DIR_NAME / LocalStorageFilenames.FLOW_TOOLS_JSON
@@ -214,10 +216,10 @@ def __init__(self, run: Run, stream=False, run_mode=RunMode.Test):
         # for line run records, store per line
         # for normal node run records, store per node per line;
         # for reduce node run records, store centralized in 000000000.jsonl per node
-        self.outputs_folder = self._prepare_folder(self.path / "flow_outputs")
-        self._outputs_path = self.outputs_folder / "output.jsonl"  # dumped by executor
-        self._node_infos_folder = self._prepare_folder(self.path / "node_artifacts")
-        self._run_infos_folder = self._prepare_folder(self.path / "flow_artifacts")
+        self.outputs_folder = prepare_folder(self.path / OutputsFolderName.FLOW_OUTPUTS)
+        self._outputs_path = self.outputs_folder / OUTPUT_FILE_NAME  # dumped by executor
+        self._node_infos_folder = prepare_folder(self.path / OutputsFolderName.NODE_ARTIFACTS)
+        self._run_infos_folder = prepare_folder(self.path / OutputsFolderName.FLOW_ARTIFACTS)
         self._data_path = Path(run.data) if run.data is not None else None

         self._meta_path = self.path / LocalStorageFilenames.META
@@ -379,7 +381,7 @@ def load_metrics(self, *, parse_const_as_str: bool = False) -> Dict[str, Union[i

     def persist_node_run(self, run_info: NodeRunInfo) -> None:
         """Persist node run record to local storage."""
-        node_folder = self._prepare_folder(self._node_infos_folder / run_info.node)
+        node_folder = prepare_folder(self._node_infos_folder / run_info.node)
         self._persist_run_multimedia(run_info, node_folder)
         node_run_record = NodeRunRecord.from_run_info(run_info)
         # for reduce nodes, the line_number is None, store the info in the 000000000.jsonl
@@ -482,12 +484,6 @@ def _serialize_multimedia(self, value, folder_path: Path, relative_path: Path =
         serialization_funcs = {Image: partial(Image.serialize, **{"encoder": pfbytes_file_reference_encoder})}
         return serialize(value, serialization_funcs=serialization_funcs)

-    @staticmethod
-    def _prepare_folder(path: Union[str, Path]) -> Path:
-        path = Path(path)
-        path.mkdir(parents=True, exist_ok=True)
-        return path
-
     @staticmethod
     def _outputs_padding(df: "DataFrame", inputs_line_numbers: List[int]) -> "DataFrame":
         import pandas as pd
diff --git a/src/promptflow/promptflow/_utils/utils.py b/src/promptflow/promptflow/_utils/utils.py
index fe155f0fee4..844913829e7 100644
--- a/src/promptflow/promptflow/_utils/utils.py
+++ b/src/promptflow/promptflow/_utils/utils.py
@@ -389,3 +389,10 @@ def in_jupyter_notebook() -> bool:

 def snake_to_camel(name):
     return re.sub(r"(?:^|_)([a-z])", lambda x: x.group(1).upper(), name)
+
+
+def prepare_folder(path: Union[str, Path]) -> Path:
+    """Create folder if not exists and return the folder path."""
+    path = Path(path)
+    path.mkdir(parents=True, exist_ok=True)
+    return path
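Together, the relocated constants and the new shared `prepare_folder` helper pin down the on-disk layout of a local run. A runnable sketch of that layout (the temp root is an arbitrary example path):

```python
import tempfile
from pathlib import Path
from typing import Union

# Mirror the constants this PR centralizes in promptflow/_constants.py.
OUTPUT_FILE_NAME = "output.jsonl"


class OutputsFolderName:
    FLOW_OUTPUTS = "flow_outputs"
    FLOW_ARTIFACTS = "flow_artifacts"
    NODE_ARTIFACTS = "node_artifacts"


def prepare_folder(path: Union[str, Path]) -> Path:
    """Create folder if not exists and return the folder path."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    return path


root = prepare_folder(Path(tempfile.mkdtemp()) / "example_run")
outputs_path = prepare_folder(root / OutputsFolderName.FLOW_OUTPUTS) / OUTPUT_FILE_NAME
node_infos_folder = prepare_folder(root / OutputsFolderName.NODE_ARTIFACTS)
run_infos_folder = prepare_folder(root / OutputsFolderName.FLOW_ARTIFACTS)
print(outputs_path)  # .../example_run/flow_outputs/output.jsonl
```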
diff --git a/src/promptflow/promptflow/batch/_batch_engine.py b/src/promptflow/promptflow/batch/_batch_engine.py
index 47e88c9e28d..f154d4e16b4 100644
--- a/src/promptflow/promptflow/batch/_batch_engine.py
+++ b/src/promptflow/promptflow/batch/_batch_engine.py
@@ -9,7 +9,7 @@
 from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional

-from promptflow._constants import LANGUAGE_KEY, LINE_NUMBER_KEY, LINE_TIMEOUT_SEC, FlowLanguage
+from promptflow._constants import LANGUAGE_KEY, LINE_NUMBER_KEY, LINE_TIMEOUT_SEC, OUTPUT_FILE_NAME, FlowLanguage
 from promptflow._core._errors import ResumeCopyError, UnexpectedError
 from promptflow._core.operation_context import OperationContext
 from promptflow._utils.async_utils import async_run_allowing_running_loop
@@ -46,7 +46,6 @@
 from promptflow.executor.flow_validator import FlowValidator
 from promptflow.storage import AbstractBatchRunStorage, AbstractRunStorage

-OUTPUT_FILE_NAME = "output.jsonl"
 DEFAULT_CONCURRENCY = 10


@@ -239,13 +238,13 @@ def _copy_previous_run_result(
         return the list of previous line results for the usage of aggregation and summarization.
         """
         # Load the previous flow run output from output.jsonl
-        previous_run_output = load_list_from_jsonl(resume_from_run_output_dir / "output.jsonl")
+        previous_run_output = load_list_from_jsonl(resume_from_run_output_dir / OUTPUT_FILE_NAME)
         previous_run_output_dict = {
             each_line_output[LINE_NUMBER_KEY]: each_line_output for each_line_output in previous_run_output
         }

         # Copy other files from resume_from_run_output_dir to output_dir in case there are images
-        copy_file_except(resume_from_run_output_dir, output_dir, "output.jsonl")
+        copy_file_except(resume_from_run_output_dir, output_dir, OUTPUT_FILE_NAME)

         try:
             previous_run_results = []
diff --git a/src/promptflow/tests/executor/e2etests/test_activate.py b/src/promptflow/tests/executor/e2etests/test_activate.py
index aa7ab9c7a96..2791ecdfcaf 100644
--- a/src/promptflow/tests/executor/e2etests/test_activate.py
+++ b/src/promptflow/tests/executor/e2etests/test_activate.py
@@ -4,8 +4,9 @@
 import pytest

+from promptflow._constants import OUTPUT_FILE_NAME
 from promptflow._utils.logger_utils import LogContext
-from promptflow.batch._batch_engine import OUTPUT_FILE_NAME, BatchEngine
+from promptflow.batch._batch_engine import BatchEngine
 from promptflow.batch._result import BatchResult
 from promptflow.contracts._errors import FlowDefinitionError
 from promptflow.contracts.run_info import FlowRunInfo
diff --git a/src/promptflow/tests/executor/e2etests/test_batch_engine.py b/src/promptflow/tests/executor/e2etests/test_batch_engine.py
index e3f77e73c50..0617105f6b1 100644
--- a/src/promptflow/tests/executor/e2etests/test_batch_engine.py
+++ b/src/promptflow/tests/executor/e2etests/test_batch_engine.py
@@ -8,10 +8,11 @@
 import pytest

+from promptflow._constants import OUTPUT_FILE_NAME
 from promptflow._sdk.entities._run import Run
 from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
 from promptflow._utils.utils import dump_list_to_jsonl
-from promptflow.batch._batch_engine import OUTPUT_FILE_NAME, BatchEngine
+from promptflow.batch._batch_engine import BatchEngine
 from promptflow.batch._errors import EmptyInputsData
 from promptflow.batch._result import BatchResult
 from promptflow.contracts.run_info import Status
diff --git a/src/promptflow/tests/executor/e2etests/test_eager_flow.py b/src/promptflow/tests/executor/e2etests/test_eager_flow.py
index 39bfc1d4103..b412843af39 100644
--- a/src/promptflow/tests/executor/e2etests/test_eager_flow.py
+++ b/src/promptflow/tests/executor/e2etests/test_eager_flow.py
@@ -4,7 +4,8 @@
 import pytest

-from promptflow.batch._batch_engine import OUTPUT_FILE_NAME, BatchEngine
+from promptflow._constants import OUTPUT_FILE_NAME
+from promptflow.batch._batch_engine import BatchEngine
 from promptflow.batch._result import BatchResult, LineResult
 from promptflow.contracts.run_info import Status
 from promptflow.executor._script_executor import ScriptExecutor
diff --git a/src/promptflow/tests/executor/e2etests/test_image.py b/src/promptflow/tests/executor/e2etests/test_image.py
index 7370648982f..b29efaf0e69 100644
--- a/src/promptflow/tests/executor/e2etests/test_image.py
+++ b/src/promptflow/tests/executor/e2etests/test_image.py
@@ -4,8 +4,9 @@
 import pytest

+from promptflow._constants import OUTPUT_FILE_NAME
 from promptflow._utils.multimedia_utils import MIME_PATTERN, _create_image_from_file, _is_url, is_multimedia_dict
-from promptflow.batch._batch_engine import OUTPUT_FILE_NAME, BatchEngine
+from promptflow.batch._batch_engine import BatchEngine
 from promptflow.batch._result import BatchResult
 from promptflow.contracts.multimedia import Image
 from promptflow.contracts.run_info import FlowRunInfo, RunInfo, Status
diff --git a/src/promptflow/tests/executor/e2etests/test_logs.py b/src/promptflow/tests/executor/e2etests/test_logs.py
index 517ab0d93c7..89a0b6f47bc 100644
--- a/src/promptflow/tests/executor/e2etests/test_logs.py
+++ b/src/promptflow/tests/executor/e2etests/test_logs.py
@@ -3,6 +3,7 @@
 import pytest

+from promptflow._constants import OUTPUT_FILE_NAME
 from promptflow._utils.logger_utils import LogContext
 from promptflow.batch import BatchEngine
 from promptflow.batch._result import BatchResult
@@ -21,7 +22,6 @@
 TEST_LOGS_FLOW = ["print_input_flow"]
 SAMPLE_FLOW_WITH_TEN_INPUTS = "simple_flow_with_ten_inputs"

-OUTPUT_FILE_NAME = "output.jsonl"


 def submit_batch_run(
diff --git a/src/promptflow/tests/executor/e2etests/test_telemetry.py b/src/promptflow/tests/executor/e2etests/test_telemetry.py
index a1f838c69e5..d2005398ec7 100644
--- a/src/promptflow/tests/executor/e2etests/test_telemetry.py
+++ b/src/promptflow/tests/executor/e2etests/test_telemetry.py
@@ -8,8 +8,9 @@
 import pytest

+from promptflow._constants import OUTPUT_FILE_NAME
 from promptflow._core.operation_context import OperationContext
-from promptflow.batch._batch_engine import OUTPUT_FILE_NAME, BatchEngine
+from promptflow.batch._batch_engine import BatchEngine
 from promptflow.batch._result import BatchResult
 from promptflow.contracts.run_mode import RunMode
 from promptflow.executor import FlowExecutor
diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
index 8186a51b328..f3f5f9911f5 100644
--- a/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
+++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_cli.py
@@ -1985,18 +1985,20 @@ def wait_for_experiment_terminated(experiment_name):
     @pytest.mark.skipif(condition=not is_live(), reason="Injection cannot passed to detach process.")
     @pytest.mark.usefixtures("setup_experiment_table")
     def test_experiment_start_anonymous_experiment(self, monkeypatch, local_client):
-        from promptflow._sdk._load_functions import _load_experiment
-
         with mock.patch("promptflow._sdk._configuration.Configuration.is_internal_features_enabled") as mock_func:
-            mock_func.return_value = True
-            experiment_file = f"{EXPERIMENT_DIR}/basic-script-template/basic-script.exp.yaml"
-            run_pf_command("experiment", "start", "--file", experiment_file, "--stream")
-            experiment = _load_experiment(source=experiment_file)
-            exp = local_client._experiments.get(name=experiment.name)
-            assert len(exp.node_runs) == 4
-            assert all(len(exp.node_runs[node_name]) > 0 for node_name in exp.node_runs)
-            metrics = local_client.runs.get_metrics(name=exp.node_runs["eval"][0]["name"])
-            assert "accuracy" in metrics
+            from promptflow._sdk.entities._experiment import Experiment
+
+            with mock.patch.object(Experiment, "_generate_name") as mock_generate_name:
+                experiment_name = str(uuid.uuid4())
+                mock_generate_name.return_value = experiment_name
+                mock_func.return_value = True
+                experiment_file = f"{EXPERIMENT_DIR}/basic-script-template/basic-script.exp.yaml"
+                run_pf_command("experiment", "start", "--template", experiment_file, "--stream")
+                exp = local_client._experiments.get(name=experiment_name)
+                assert len(exp.node_runs) == 4
+                assert all(len(exp.node_runs[node_name]) > 0 for node_name in exp.node_runs)
+                metrics = local_client.runs.get_metrics(name=exp.node_runs["eval"][0]["name"])
+                assert "accuracy" in metrics

     @pytest.mark.usefixtures("setup_experiment_table", "recording_injection")
     def test_experiment_test(self, monkeypatch, capfd, local_client, tmpdir):
diff --git a/src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py b/src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py
index ee05581e848..2ac3a1a2993 100644
--- a/src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py
+++ b/src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py
@@ -32,6 +32,22 @@ def test_list_connections(self, pf_client: PFClient, pfs_op: PFSOperations) -> N

         assert len(connections) >= 1

+    def test_list_connections_with_different_user_agent(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
+        create_custom_connection(pf_client)
+        base_user_agent = ["local_pfs/0.0.1"]
+        for _, extra_user_agent in enumerate(
+            [
+                ["another_test_user_agent/0.0.1"],
+                ["test_user_agent/0.0.1"],
+                ["another_test_user_agent/0.0.1"],
+                ["test_user_agent/0.0.1"],
+            ]
+        ):
+            with check_activity_end_telemetry(
+                activity_name="pf.connections.list", user_agent=base_user_agent + extra_user_agent
+            ):
+                pfs_op.list_connections(user_agent=extra_user_agent)
+
     def test_get_connection(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
         name = create_custom_connection(pf_client)
         with check_activity_end_telemetry(activity_name="pf.connections.get"):
diff --git a/src/promptflow/tests/sdk_pfs_test/utils.py b/src/promptflow/tests/sdk_pfs_test/utils.py
index 3d1c4c65fff..7ebfb202216 100644
--- a/src/promptflow/tests/sdk_pfs_test/utils.py
+++ b/src/promptflow/tests/sdk_pfs_test/utils.py
@@ -31,7 +31,7 @@ def check_activity_end_telemetry(
         "first_call": True,
         "activity_type": "PublicApi",
         "completion_status": "Success",
-        "user_agent": f"promptflow-sdk/0.0.1 Werkzeug/{werkzeug.__version__} local_pfs/0.0.1",
+        "user_agent": [f"Werkzeug/{werkzeug.__version__}", "local_pfs/0.0.1"],
     }
     for i, expected_activity in enumerate(expected_activities):
         temp = default_expected_call.copy()
@@ -39,6 +39,9 @@ def check_activity_end_telemetry(
         expected_activity = temp
         for key, expected_value in expected_activity.items():
             value = actual_activities[i][key]
+            if isinstance(expected_value, list):
+                value = list(sorted(value.split(" ")))
+                expected_value = list(sorted(expected_value))
             assert (
                 value == expected_value
             ), f"{key} mismatch in {i+1}th call: expect {expected_value} but got {value}"
@@ -54,7 +57,12 @@ class PFSOperations:
     def __init__(self, client: FlaskClient):
         self._client = client

-    def remote_user_header(self):
+    def remote_user_header(self, user_agent=None):
+        if user_agent:
+            return {
+                "X-Remote-User": getpass.getuser(),
+                "User-Agent": user_agent,
+            }
         return {"X-Remote-User": getpass.getuser()}

     def heartbeat(self):
@@ -67,8 +75,10 @@ def connection_operation_with_invalid_user(self, status_code=None):
         assert status_code == response.status_code, response.text
         return response

-    def list_connections(self, status_code=None):
-        response = self._client.get(f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header())
+    def list_connections(self, status_code=None, user_agent=None):
+        response = self._client.get(
+            f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header(user_agent=user_agent)
+        )
         if status_code:
             assert status_code == response.status_code, response.text
         return response
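The updated `check_activity_end_telemetry` above treats a list-valued expectation as an unordered bag of user agent tokens, since products can be appended to the user agent string in any order. The essence of that comparison, as a hypothetical standalone helper:

```python
from typing import List


def user_agents_match(actual: str, expected_parts: List[str]) -> bool:
    # A user agent is a space-joined bag of "product/version" tokens whose
    # order may vary per request, so compare sorted token lists instead of
    # comparing the raw strings.
    return sorted(actual.split(" ")) == sorted(expected_parts)


assert user_agents_match(
    "test_user_agent/0.0.1 local_pfs/0.0.1",
    ["local_pfs/0.0.1", "test_user_agent/0.0.1"],
)
assert not user_agents_match("local_pfs/0.0.1", ["other/1.0"])
```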