From 9d307e057e35197f15f3d3aac0682a31023a5804 Mon Sep 17 00:00:00 2001 From: zhangxingzhi Date: Tue, 12 Mar 2024 10:14:08 +0800 Subject: [PATCH] refactor: initialize bonded user agent in constructor of PFClient --- src/promptflow/promptflow/_constants.py | 1 + src/promptflow/promptflow/_sdk/_pf_client.py | 60 ++++++++++++++++--- .../promptflow/_sdk/_service/utils/utils.py | 8 +-- .../promptflow/_sdk/_telemetry/telemetry.py | 13 +--- src/promptflow/promptflow/_sdk/_utils.py | 36 ----------- .../_sdk/operations/_flow_operations.py | 4 +- 6 files changed, 61 insertions(+), 61 deletions(-) diff --git a/src/promptflow/promptflow/_constants.py b/src/promptflow/promptflow/_constants.py index 6ddf01d05998..5533100cba70 100644 --- a/src/promptflow/promptflow/_constants.py +++ b/src/promptflow/promptflow/_constants.py @@ -18,6 +18,7 @@ ERROR_RESPONSE_COMPONENT_NAME = "promptflow" EXTENSION_UA = "prompt-flow-extension" LANGUAGE_KEY = "language" +BONDED_USER_AGENT_KEY = "bonded_user_agent" # Tool meta info ICON_DARK = "icon_dark" diff --git a/src/promptflow/promptflow/_sdk/_pf_client.py b/src/promptflow/promptflow/_sdk/_pf_client.py index c6f99a0ef6e8..5951d08fa694 100644 --- a/src/promptflow/promptflow/_sdk/_pf_client.py +++ b/src/promptflow/promptflow/_sdk/_pf_client.py @@ -6,12 +6,14 @@ from pathlib import Path from typing import Any, Dict, List, Union +from .._constants import BONDED_USER_AGENT_KEY from .._utils.logger_utils import get_cli_sdk_logger +from ..exceptions import ErrorTarget, UserErrorException from ._configuration import Configuration -from ._constants import MAX_SHOW_DETAILS_RESULTS +from ._constants import MAX_SHOW_DETAILS_RESULTS, ConnectionProvider from ._load_functions import load_flow from ._user_agent import USER_AGENT -from ._utils import ClientUserAgentUtil, get_connection_operation, setup_user_agent_to_operation_context +from ._utils import ClientUserAgentUtil, setup_user_agent_to_operation_context from .entities import Run from .entities._eager_flow import EagerFlow from .operations import RunOperations @@ -34,20 +36,26 @@ class PFClient: def __init__(self, **kwargs): logger.debug("PFClient init with kwargs: %s", kwargs) - self._runs = RunOperations(self) + # when this is set, telemetry from this client will use this user agent instead of the one from OperationContext + if isinstance(kwargs.get(BONDED_USER_AGENT_KEY), str): + self._bonded_user_agent = kwargs[BONDED_USER_AGENT_KEY] self._connection_provider = kwargs.pop("connection_provider", None) self._config = kwargs.get("config", None) or {} # The credential is used as an option to override # DefaultAzureCredential when using workspace connection provider self._credential = kwargs.get("credential", None) - # Lazy init to avoid azure credential requires too early + + # bonded_user_agent will be applied to all TelemetryMixin operations + self._runs = RunOperations(self, bonded_user_agent=self._bonded_user_agent) + self._flows = FlowOperations(client=self, bonded_user_agent=self._bonded_user_agent) + self._experiments = ExperimentOperations(self, bonded_user_agent=self._bonded_user_agent) + # Lazy init to avoid azure credential requires too early; also need to apply bonded_user_agent self._connections = None - self._flows = FlowOperations(client=self) + self._tools = ToolOperations() # add user agent from kwargs if any if isinstance(kwargs.get("user_agent"), str): ClientUserAgentUtil.append_user_agent(kwargs["user_agent"]) - self._experiments = ExperimentOperations(self) self._traces = TraceOperations() setup_user_agent_to_operation_context(USER_AGENT) @@ -243,9 +251,47 @@ def connections(self) -> ConnectionOperations: """Connection operations that can manage connections.""" if not self._connections: self._ensure_connection_provider() - self._connections = get_connection_operation(self._connection_provider, self._credential) + self._connections = PFClient._build_connection_operation( + self._connection_provider, + self._credential, + bonded_user_agent=self._bonded_user_agent, + ) return self._connections + @staticmethod + def _build_connection_operation(connection_provider: str, credential=None, user_agent: str = None, **kwargs): + """ + Build a ConnectionOperation object based on connection provider. + + :param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc. + :type connection_provider: str + :param credential: Credential when remote provider, default to chained credential DefaultAzureCredential. + :type credential: object + :param user_agent: User Agent + :type user_agent: str + """ + if connection_provider == ConnectionProvider.LOCAL.value: + from promptflow._sdk.operations._connection_operations import ConnectionOperations + + logger.debug("PFClient using local connection operations.") + connection_operation = ConnectionOperations(**kwargs) + elif connection_provider.startswith(ConnectionProvider.AZUREML.value): + from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations + + logger.debug(f"PFClient using local azure connection operations with credential {credential}.") + if user_agent is None: + connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential) + else: + connection_operation = LocalAzureConnectionOperations(connection_provider, user_agent=user_agent) + else: + error = ValueError(f"Unsupported connection provider: {connection_provider}") + raise UserErrorException( + target=ErrorTarget.CONTROL_PLANE_SDK, + message=str(error), + error=error, + ) + return connection_operation + @property def flows(self) -> FlowOperations: """Operations on the flow that can manage flows.""" diff --git a/src/promptflow/promptflow/_sdk/_service/utils/utils.py b/src/promptflow/promptflow/_sdk/_service/utils/utils.py index 96be9a087539..814628f60750 100644 --- a/src/promptflow/promptflow/_sdk/_service/utils/utils.py +++ b/src/promptflow/promptflow/_sdk/_service/utils/utils.py @@ -250,13 +250,9 @@ def get_client_from_request(*, connection_provider=None) -> "PFClient": user_agent = build_pfs_user_agent() if connection_provider: - pf_client = PFClient(connection_provider=connection_provider) + pf_client = PFClient(connection_provider=connection_provider, bonded_user_agent=user_agent) else: - pf_client = PFClient() - # DO NOT pass in user agent directly to PFClient, as it will impact the global OperationContext. - pf_client.connections._bond_user_agent(user_agent) - pf_client.runs._bond_user_agent(user_agent) - pf_client.flows._bond_user_agent(user_agent) + pf_client = PFClient(bonded_user_agent=user_agent) return pf_client diff --git a/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py b/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py index 5eff8967892d..d0192398ee00 100644 --- a/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py +++ b/src/promptflow/promptflow/_sdk/_telemetry/telemetry.py @@ -4,6 +4,7 @@ import logging from typing import Optional +from promptflow._constants import BONDED_USER_AGENT_KEY from promptflow._sdk._configuration import Configuration PROMPTFLOW_LOGGER_NAMESPACE = "promptflow._sdk._telemetry" @@ -14,7 +15,7 @@ def __init__(self, **kwargs): # Need to call init for potential parent, otherwise it won't be initialized. super().__init__(**kwargs) - self._user_agent = kwargs.get("user_agent", None) + self._bonded_user_agent = kwargs.get(BONDED_USER_AGENT_KEY, None) def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argument """Return the telemetry values of object. @@ -29,15 +30,7 @@ def _get_bonded_user_agent(self) -> Optional[str]: This user agent will be used in telemetry if specified instead of user agent from OperationContext. """ - return self._user_agent - - def _bond_user_agent(self, user_agent: str): - """Bond a user agent to the object. - - :param user_agent: The user agent to bond. - :type user_agent: str - """ - self._user_agent = user_agent + return self._bonded_user_agent class WorkspaceTelemetryMixin(TelemetryMixin): diff --git a/src/promptflow/promptflow/_sdk/_utils.py b/src/promptflow/promptflow/_sdk/_utils.py index f1b6f6f6a5c2..e8ef21030bcf 100644 --- a/src/promptflow/promptflow/_sdk/_utils.py +++ b/src/promptflow/promptflow/_sdk/_utils.py @@ -56,7 +56,6 @@ VARIANTS, AzureMLWorkspaceTriad, CommonYamlFields, - ConnectionProvider, ) from promptflow._sdk._errors import ( DecryptConnectionError, @@ -1045,41 +1044,6 @@ def parse_remote_flow_pattern(flow: object) -> str: return flow_name -def get_connection_operation(connection_provider: str, credential=None, user_agent: str = None): - """ - Get connection operation based on connection provider. - This function will be called by PFClient, so please do not refer to PFClient in this function. - - :param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc. - :type connection_provider: str - :param credential: Credential when remote provider, default to chained credential DefaultAzureCredential. - :type credential: object - :param user_agent: User Agent - :type user_agent: str - """ - if connection_provider == ConnectionProvider.LOCAL.value: - from promptflow._sdk.operations._connection_operations import ConnectionOperations - - logger.debug("PFClient using local connection operations.") - connection_operation = ConnectionOperations() - elif connection_provider.startswith(ConnectionProvider.AZUREML.value): - from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations - - logger.debug(f"PFClient using local azure connection operations with credential {credential}.") - if user_agent is None: - connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential) - else: - connection_operation = LocalAzureConnectionOperations(connection_provider, user_agent=user_agent) - else: - error = ValueError(f"Unsupported connection provider: {connection_provider}") - raise UserErrorException( - target=ErrorTarget.CONTROL_PLANE_SDK, - message=str(error), - error=error, - ) - return connection_operation - - # extract open read/write as partial to centralize the encoding read_open = partial(open, mode="r", encoding=DEFAULT_ENCODING) write_open = partial(open, mode="w", encoding=DEFAULT_ENCODING) diff --git a/src/promptflow/promptflow/_sdk/operations/_flow_operations.py b/src/promptflow/promptflow/_sdk/operations/_flow_operations.py index 64a850192deb..d816cf945afd 100644 --- a/src/promptflow/promptflow/_sdk/operations/_flow_operations.py +++ b/src/promptflow/promptflow/_sdk/operations/_flow_operations.py @@ -48,9 +48,9 @@ class FlowOperations(TelemetryMixin): """FlowOperations.""" - def __init__(self, client): + def __init__(self, client, **kwargs): + super().__init__(**kwargs) self._client = client - super().__init__() @monitor_operation(activity_name="pf.flows.test", activity_type=ActivityType.PUBLICAPI) def test(