Skip to content

Commit

Permalink
refactor: initialize bonded user agent in constructor of PFClient
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Mar 12, 2024
1 parent 668685d commit 9d307e0
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 61 deletions.
1 change: 1 addition & 0 deletions src/promptflow/promptflow/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ERROR_RESPONSE_COMPONENT_NAME = "promptflow"
EXTENSION_UA = "prompt-flow-extension"
LANGUAGE_KEY = "language"
BONDED_USER_AGENT_KEY = "bonded_user_agent"

# Tool meta info
ICON_DARK = "icon_dark"
Expand Down
60 changes: 53 additions & 7 deletions src/promptflow/promptflow/_sdk/_pf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from pathlib import Path
from typing import Any, Dict, List, Union

from .._constants import BONDED_USER_AGENT_KEY
from .._utils.logger_utils import get_cli_sdk_logger
from ..exceptions import ErrorTarget, UserErrorException
from ._configuration import Configuration
from ._constants import MAX_SHOW_DETAILS_RESULTS
from ._constants import MAX_SHOW_DETAILS_RESULTS, ConnectionProvider
from ._load_functions import load_flow
from ._user_agent import USER_AGENT
from ._utils import ClientUserAgentUtil, get_connection_operation, setup_user_agent_to_operation_context
from ._utils import ClientUserAgentUtil, setup_user_agent_to_operation_context
from .entities import Run
from .entities._eager_flow import EagerFlow
from .operations import RunOperations
Expand All @@ -34,20 +36,26 @@ class PFClient:

def __init__(self, **kwargs):
logger.debug("PFClient init with kwargs: %s", kwargs)
self._runs = RunOperations(self)
# when this is set, telemetry from this client will use this user agent instead of the one from OperationContext
if isinstance(kwargs.get(BONDED_USER_AGENT_KEY), str):
self._bonded_user_agent = kwargs[BONDED_USER_AGENT_KEY]
self._connection_provider = kwargs.pop("connection_provider", None)
self._config = kwargs.get("config", None) or {}
# The credential is used as an option to override
# DefaultAzureCredential when using workspace connection provider
self._credential = kwargs.get("credential", None)
# Lazy init to avoid azure credential requires too early

# bonded_user_agent will be applied to all TelemetryMixin operations
self._runs = RunOperations(self, bonded_user_agent=self._bonded_user_agent)
self._flows = FlowOperations(client=self, bonded_user_agent=self._bonded_user_agent)
self._experiments = ExperimentOperations(self, bonded_user_agent=self._bonded_user_agent)
# Lazy init to avoid azure credential requires too early; also need to apply bonded_user_agent
self._connections = None
self._flows = FlowOperations(client=self)

self._tools = ToolOperations()
# add user agent from kwargs if any
if isinstance(kwargs.get("user_agent"), str):
ClientUserAgentUtil.append_user_agent(kwargs["user_agent"])
self._experiments = ExperimentOperations(self)
self._traces = TraceOperations()
setup_user_agent_to_operation_context(USER_AGENT)

Expand Down Expand Up @@ -243,9 +251,47 @@ def connections(self) -> ConnectionOperations:
"""Connection operations that can manage connections."""
if not self._connections:
self._ensure_connection_provider()
self._connections = get_connection_operation(self._connection_provider, self._credential)
self._connections = PFClient._build_connection_operation(
self._connection_provider,
self._credential,
bonded_user_agent=self._bonded_user_agent,
)
return self._connections

@staticmethod
def _build_connection_operation(connection_provider: str, credential=None, user_agent: str = None, **kwargs):
"""
Build a ConnectionOperation object based on connection provider.
:param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc.
:type connection_provider: str
:param credential: Credential when remote provider, default to chained credential DefaultAzureCredential.
:type credential: object
:param user_agent: User Agent
:type user_agent: str
"""
if connection_provider == ConnectionProvider.LOCAL.value:
from promptflow._sdk.operations._connection_operations import ConnectionOperations

logger.debug("PFClient using local connection operations.")
connection_operation = ConnectionOperations(**kwargs)
elif connection_provider.startswith(ConnectionProvider.AZUREML.value):
from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations

logger.debug(f"PFClient using local azure connection operations with credential {credential}.")
if user_agent is None:
connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential)
else:
connection_operation = LocalAzureConnectionOperations(connection_provider, user_agent=user_agent)
else:
error = ValueError(f"Unsupported connection provider: {connection_provider}")
raise UserErrorException(
target=ErrorTarget.CONTROL_PLANE_SDK,
message=str(error),
error=error,
)
return connection_operation

@property
def flows(self) -> FlowOperations:
"""Operations on the flow that can manage flows."""
Expand Down
8 changes: 2 additions & 6 deletions src/promptflow/promptflow/_sdk/_service/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,9 @@ def get_client_from_request(*, connection_provider=None) -> "PFClient":
user_agent = build_pfs_user_agent()

if connection_provider:
pf_client = PFClient(connection_provider=connection_provider)
pf_client = PFClient(connection_provider=connection_provider, bonded_user_agent=user_agent)
else:
pf_client = PFClient()
# DO NOT pass in user agent directly to PFClient, as it will impact the global OperationContext.
pf_client.connections._bond_user_agent(user_agent)
pf_client.runs._bond_user_agent(user_agent)
pf_client.flows._bond_user_agent(user_agent)
pf_client = PFClient(bonded_user_agent=user_agent)
return pf_client


Expand Down
13 changes: 3 additions & 10 deletions src/promptflow/promptflow/_sdk/_telemetry/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from typing import Optional

from promptflow._constants import BONDED_USER_AGENT_KEY
from promptflow._sdk._configuration import Configuration

PROMPTFLOW_LOGGER_NAMESPACE = "promptflow._sdk._telemetry"
Expand All @@ -14,7 +15,7 @@ def __init__(self, **kwargs):
# Need to call init for potential parent, otherwise it won't be initialized.
super().__init__(**kwargs)

self._user_agent = kwargs.get("user_agent", None)
self._bonded_user_agent = kwargs.get(BONDED_USER_AGENT_KEY, None)

def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argument
"""Return the telemetry values of object.
Expand All @@ -29,15 +30,7 @@ def _get_bonded_user_agent(self) -> Optional[str]:
This user agent will be used in telemetry if specified instead of user agent from OperationContext.
"""
return self._user_agent

def _bond_user_agent(self, user_agent: str):
"""Bond a user agent to the object.
:param user_agent: The user agent to bond.
:type user_agent: str
"""
self._user_agent = user_agent
return self._bonded_user_agent


class WorkspaceTelemetryMixin(TelemetryMixin):
Expand Down
36 changes: 0 additions & 36 deletions src/promptflow/promptflow/_sdk/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
VARIANTS,
AzureMLWorkspaceTriad,
CommonYamlFields,
ConnectionProvider,
)
from promptflow._sdk._errors import (
DecryptConnectionError,
Expand Down Expand Up @@ -1045,41 +1044,6 @@ def parse_remote_flow_pattern(flow: object) -> str:
return flow_name


def get_connection_operation(connection_provider: str, credential=None, user_agent: str = None):
"""
Get connection operation based on connection provider.
This function will be called by PFClient, so please do not refer to PFClient in this function.
:param connection_provider: Connection provider, e.g. local, azureml, azureml://subscriptions..., etc.
:type connection_provider: str
:param credential: Credential when remote provider, default to chained credential DefaultAzureCredential.
:type credential: object
:param user_agent: User Agent
:type user_agent: str
"""
if connection_provider == ConnectionProvider.LOCAL.value:
from promptflow._sdk.operations._connection_operations import ConnectionOperations

logger.debug("PFClient using local connection operations.")
connection_operation = ConnectionOperations()
elif connection_provider.startswith(ConnectionProvider.AZUREML.value):
from promptflow._sdk.operations._local_azure_connection_operations import LocalAzureConnectionOperations

logger.debug(f"PFClient using local azure connection operations with credential {credential}.")
if user_agent is None:
connection_operation = LocalAzureConnectionOperations(connection_provider, credential=credential)
else:
connection_operation = LocalAzureConnectionOperations(connection_provider, user_agent=user_agent)
else:
error = ValueError(f"Unsupported connection provider: {connection_provider}")
raise UserErrorException(
target=ErrorTarget.CONTROL_PLANE_SDK,
message=str(error),
error=error,
)
return connection_operation


# extract open read/write as partial to centralize the encoding
read_open = partial(open, mode="r", encoding=DEFAULT_ENCODING)
write_open = partial(open, mode="w", encoding=DEFAULT_ENCODING)
Expand Down
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_sdk/operations/_flow_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
class FlowOperations(TelemetryMixin):
"""FlowOperations."""

def __init__(self, client):
def __init__(self, client, **kwargs):
super().__init__(**kwargs)
self._client = client
super().__init__()

@monitor_operation(activity_name="pf.flows.test", activity_type=ActivityType.PUBLICAPI)
def test(
Expand Down

0 comments on commit 9d307e0

Please sign in to comment.