Skip to content

Commit

Permalink
feat: allow using msi port file location via environment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Mar 11, 2024
1 parent e8c7aa9 commit 668685d
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 26 deletions.
3 changes: 3 additions & 0 deletions src/promptflow/promptflow/_core/operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def append_user_agent(self, user_agent: str):
user_agent (str): The user agent information to append.
"""
if OperationContext.USER_AGENT_KEY in self:
# TODO: this judgement can be wrong when an user agent is a substring of another,
# e.g. "Mozilla/5.0" and "Mozilla/5.0 (Windows NT 10.0; Win64; x64)"
# however, changing this code may impact existing logic, so won't change it now
if user_agent not in self.user_agent:
self.user_agent = f"{self.user_agent.strip()} {user_agent.strip()}"
else:
Expand Down
8 changes: 2 additions & 6 deletions src/promptflow/promptflow/_sdk/_service/apis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import promptflow._sdk.schemas._connection as connection
from promptflow._sdk._configuration import Configuration
from promptflow._sdk._service import Namespace, Resource, fields
from promptflow._sdk._service.utils.utils import build_pfs_user_agent, local_user_only, make_response_no_content
from promptflow._sdk._service.utils.utils import get_client_from_request, local_user_only, make_response_no_content
from promptflow._sdk.entities._connection import _Connection

api = Namespace("Connections", description="Connections Management")
Expand Down Expand Up @@ -66,14 +66,10 @@ def validate_working_directory(value):


def _get_connection_operation(working_directory=None):
from promptflow._sdk._pf_client import PFClient

connection_provider = Configuration().get_connection_provider(path=working_directory)
# get_connection_operation is a shared function, so we build user agent based on request first and
# then pass it to the function
connection_operation = PFClient(
connection_provider=connection_provider, user_agent=build_pfs_user_agent()
).connections
connection_operation = get_client_from_request(connection_provider=connection_provider).connections
return connection_operation


Expand Down
28 changes: 22 additions & 6 deletions src/promptflow/promptflow/_sdk/_service/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,32 @@ def __post_init__(self, exception, status_code):


def build_pfs_user_agent():
extra_agent = f"local_pfs/{VERSION}"
if request.user_agent.string:
return f"{request.user_agent.string} {extra_agent}"
return extra_agent
user_agent = request.user_agent.string
extra_user_agent = f"local_pfs/{VERSION}"
if user_agent:
return f"{user_agent} {extra_user_agent}"
return extra_user_agent


def get_client_from_request() -> "PFClient":
def get_client_from_request(*, connection_provider=None) -> "PFClient":
"""
Build a PFClient instance based on current request in local PFS.
User agent may be different for each request.
"""
from promptflow._sdk._pf_client import PFClient

return PFClient(user_agent=build_pfs_user_agent())
user_agent = build_pfs_user_agent()

if connection_provider:
pf_client = PFClient(connection_provider=connection_provider)
else:
pf_client = PFClient()
# DO NOT pass in user agent directly to PFClient, as it will impact the global OperationContext.
pf_client.connections._bond_user_agent(user_agent)
pf_client.runs._bond_user_agent(user_agent)
pf_client.flows._bond_user_agent(user_agent)
return pf_client


def is_run_from_built_binary():
Expand Down
32 changes: 23 additions & 9 deletions src/promptflow/promptflow/_sdk/_telemetry/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def log_activity(
activity_name,
activity_type=ActivityType.INTERNALCALL,
custom_dimensions=None,
user_agent=None,
):
"""Log an activity.
Expand All @@ -121,12 +122,14 @@ def log_activity(
:type activity_type: str
:param custom_dimensions: The custom properties of the activity.
:type custom_dimensions: dict
:param user_agent: Specify user agent. If not specified, the user agent will be got from OperationContext.
:type user_agent: str
:return: None
"""
if not custom_dimensions:
custom_dimensions = {}

user_agent = ClientUserAgentUtil.get_user_agent()
user_agent = user_agent or ClientUserAgentUtil.get_user_agent()
request_id = request_id_context.get()
if not request_id:
# public function call
Expand Down Expand Up @@ -179,15 +182,19 @@ def log_activity(
raise exception


def extract_telemetry_info(self):
"""Extract pf telemetry info from given telemetry mix-in instance."""
result = {}
def extract_telemetry_info(telemetry_mixin):
"""Extract pf telemetry info from given telemetry mix-in instance.
:param telemetry_mixin: telemetry mix-in instance.
:type telemetry_mixin: TelemetryMixin
:return: custom dimensions and user agent in telemetry.
:rtype: Tuple[Dict, Optional[str]]
"""
try:
if isinstance(self, TelemetryMixin):
return self._get_telemetry_values()
if isinstance(telemetry_mixin, TelemetryMixin):
return telemetry_mixin._get_telemetry_values(), telemetry_mixin._get_bonded_user_agent()
except Exception:
pass
return result
return {}, None


def update_activity_name(activity_name, kwargs=None, args=None):
Expand Down Expand Up @@ -233,10 +240,17 @@ def wrapper(self, *args, **kwargs):

logger = get_telemetry_logger()

custom_dimensions.update(extract_telemetry_info(self))
extra_custom_dimensions, user_agent = extract_telemetry_info(self)
custom_dimensions.update(extra_custom_dimensions)
# update activity name according to kwargs.
_activity_name = update_activity_name(activity_name, kwargs=kwargs)
with log_activity(logger, _activity_name, activity_type, custom_dimensions):
with log_activity(
logger=logger,
activity_name=_activity_name,
activity_type=activity_type,
custom_dimensions=custom_dimensions,
user_agent=user_agent,
):
if _activity_name in HINT_ACTIVITY_NAME:
hint_for_update()
# set check_latest_version as deamon thread to avoid blocking main thread
Expand Down
19 changes: 19 additions & 0 deletions src/promptflow/promptflow/_sdk/_telemetry/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import logging
from typing import Optional

from promptflow._sdk._configuration import Configuration

Expand All @@ -13,6 +14,8 @@ def __init__(self, **kwargs):
# Need to call init for potential parent, otherwise it won't be initialized.
super().__init__(**kwargs)

self._user_agent = kwargs.get("user_agent", None)

def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argument
"""Return the telemetry values of object.
Expand All @@ -21,6 +24,21 @@ def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argu
"""
return {}

def _get_bonded_user_agent(self) -> Optional[str]:
"""If we have a bonded user agent (passed in via the constructor or _bond_user_agent), return it.
This user agent will be used in telemetry if specified instead of user agent from OperationContext.
"""
return self._user_agent

def _bond_user_agent(self, user_agent: str):
"""Bond a user agent to the object.
:param user_agent: The user agent to bond.
:type user_agent: str
"""
self._user_agent = user_agent


class WorkspaceTelemetryMixin(TelemetryMixin):
def __init__(self, subscription_id, resource_group_name, workspace_name, **kwargs):
Expand All @@ -37,6 +55,7 @@ def _get_telemetry_values(self, *args, **kwargs): # pylint: disable=unused-argu
:rtype: Dict
"""
return {
**super()._get_telemetry_values(),
"subscription_id": self._telemetry_subscription_id,
"resource_group_name": self._telemetry_resource_group_name,
"workspace_name": self._telemetry_workspace_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self):
self._from_cli = False

def _get_telemetry_values(self, *args, **kwargs):
return {"request_id": self._request_id, "from_cli": self._from_cli}
return {**super()._get_telemetry_values(), "request_id": self._request_id, "from_cli": self._from_cli}

def _set_from_cli_for_telemetry(self):
self._from_cli = True
Expand Down
16 changes: 16 additions & 0 deletions src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def test_list_connections(self, pf_client: PFClient, pfs_op: PFSOperations) -> N

assert len(connections) >= 1

def test_list_connections_with_different_user_agent(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
create_custom_connection(pf_client)
base_user_agent = ["local_pfs/0.0.1"]
for _, extra_user_agent in enumerate(
[
["another_test_user_agent/0.0.1"],
["test_user_agent/0.0.1"],
["another_test_user_agent/0.0.1"],
["test_user_agent/0.0.1"],
]
):
with check_activity_end_telemetry(
activity_name="pf.connections.list", user_agent=base_user_agent + extra_user_agent
):
pfs_op.list_connections(user_agent=extra_user_agent)

def test_get_connection(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
name = create_custom_connection(pf_client)
with check_activity_end_telemetry(activity_name="pf.connections.get"):
Expand Down
18 changes: 14 additions & 4 deletions src/promptflow/tests/sdk_pfs_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ def check_activity_end_telemetry(
"first_call": True,
"activity_type": "PublicApi",
"completion_status": "Success",
"user_agent": f"promptflow-sdk/0.0.1 Werkzeug/{werkzeug.__version__} local_pfs/0.0.1",
"user_agent": [f"Werkzeug/{werkzeug.__version__}", "local_pfs/0.0.1"],
}
for i, expected_activity in enumerate(expected_activities):
temp = default_expected_call.copy()
temp.update(expected_activity)
expected_activity = temp
for key, expected_value in expected_activity.items():
value = actual_activities[i][key]
if isinstance(expected_value, list):
value = list(sorted(value.split(" ")))
expected_value = list(sorted(expected_value))
assert (
value == expected_value
), f"{key} mismatch in {i+1}th call: expect {expected_value} but got {value}"
Expand All @@ -54,7 +57,12 @@ class PFSOperations:
def __init__(self, client: FlaskClient):
self._client = client

def remote_user_header(self):
def remote_user_header(self, user_agent=None):
if user_agent:
return {
"X-Remote-User": getpass.getuser(),
"User-Agent": user_agent,
}
return {"X-Remote-User": getpass.getuser()}

def heartbeat(self):
Expand All @@ -67,8 +75,10 @@ def connection_operation_with_invalid_user(self, status_code=None):
assert status_code == response.status_code, response.text
return response

def list_connections(self, status_code=None):
response = self._client.get(f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header())
def list_connections(self, status_code=None, user_agent=None):
response = self._client.get(
f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header(user_agent=user_agent)
)
if status_code:
assert status_code == response.status_code, response.text
return response
Expand Down

0 comments on commit 668685d

Please sign in to comment.