Skip to content

Commit

Permalink
fix: reset user agent for each requests in local pfs
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Mar 11, 2024
1 parent f4215a4 commit 2657a08
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 9 deletions.
9 changes: 7 additions & 2 deletions src/promptflow/promptflow/_core/operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,13 @@ def append_user_agent(self, user_agent: str):
user_agent (str): The user agent information to append.
"""
if OperationContext.USER_AGENT_KEY in self:
if user_agent not in self.user_agent:
self.user_agent = f"{self.user_agent.strip()} {user_agent.strip()}"
user_agent_separator = " "
new_user_agents = user_agent.strip().split(user_agent_separator)
current_user_agents = self.user_agent.strip().split(user_agent_separator)
for new_user_agent in new_user_agents:
if new_user_agent not in current_user_agents:
current_user_agents.append(new_user_agent)
self.user_agent = user_agent_separator.join(current_user_agents)
else:
self.user_agent = user_agent

Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_sdk/_service/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def validate_port(port, force_start):
port = get_port_from_config(create_if_not_exists=True)
validate_port(port, args.force)

if sys.executable.endswith("pfcli.exe"):
if sys.executable.endswith("pfcli.exe") or args.debug:
# For msi installer, use sdk api to start pfs since it's not supported to invoke waitress by cli directly
# after packaged by Pyinstaller.
app, _ = create_app()
Expand Down
20 changes: 18 additions & 2 deletions src/promptflow/promptflow/_sdk/_service/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
PF_SERVICE_PORT_FILE,
)
from promptflow._sdk._errors import ConnectionNotFoundError, RunNotFoundError
from promptflow._sdk._utils import get_promptflow_sdk_version, read_write_by_user
from promptflow._sdk._utils import ClientUserAgentUtil, get_promptflow_sdk_version, read_write_by_user
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow._utils.yaml_utils import dump_yaml, load_yaml
from promptflow._version import VERSION
Expand Down Expand Up @@ -231,13 +231,29 @@ def __post_init__(self, exception, status_code):


def build_pfs_user_agent():
# For local pfs:
# 1. the operation context won't be cleared after each request.
# 2. user agent can be different for each request.
# So, we need to reset user agent before each request.
# TODO: this may impact async requests as operation context may be changed by another request during one request.
# 1) the best practice should be passing user agent in kwargs of operations.
# 2) given local pfs is for authoring and user don't have multiple agents for now, we can wait for feedback.
ClientUserAgentUtil.pop_current_user_agent()

extra_agent = f"local_pfs/{VERSION}"
if request.user_agent.string:
return f"{request.user_agent.string} {extra_agent}"
return extra_agent


def get_client_from_request() -> "PFClient":
"""
Build a PFClient instance based on current request in local PFS.
User agent may be different for each request.
"""
from promptflow._sdk._pf_client import PFClient

return PFClient(user_agent=build_pfs_user_agent())
pf_client = PFClient(user_agent=build_pfs_user_agent())
_ = ClientUserAgentUtil.get_user_agent()
return pf_client
7 changes: 7 additions & 0 deletions src/promptflow/promptflow/_sdk/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,13 @@ def get_user_agent(cls):
# directly get from context since client side won't need promptflow/xxx.
return context.get(OperationContext.USER_AGENT_KEY, "").strip()

@classmethod
def pop_current_user_agent(cls):
from promptflow._core.operation_context import OperationContext

context = cls._get_context()
return context.pop(OperationContext.USER_AGENT_KEY, None)

@classmethod
def append_user_agent(cls, user_agent: Optional[str]):
if not user_agent:
Expand Down
16 changes: 16 additions & 0 deletions src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def test_list_connections(self, pf_client: PFClient, pfs_op: PFSOperations) -> N

assert len(connections) >= 1

def test_list_connections_with_different_user_agent(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
create_custom_connection(pf_client)
base_user_agent = ["promptflow-sdk/0.0.1", "local_pfs/0.0.1"]
for _, extra_user_agent in enumerate(
[
["another_test_user_agent/0.0.1"],
["test_user_agent/0.0.1"],
["another_test_user_agent/0.0.1"],
["test_user_agent/0.0.1"],
]
):
with check_activity_end_telemetry(
activity_name="pf.connections.list", user_agent=base_user_agent + extra_user_agent
):
pfs_op.list_connections(user_agent=extra_user_agent)

def test_get_connection(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
name = create_custom_connection(pf_client)
with check_activity_end_telemetry(activity_name="pf.connections.get"):
Expand Down
18 changes: 14 additions & 4 deletions src/promptflow/tests/sdk_pfs_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ def check_activity_end_telemetry(
"first_call": True,
"activity_type": "PublicApi",
"completion_status": "Success",
"user_agent": f"promptflow-sdk/0.0.1 Werkzeug/{werkzeug.__version__} local_pfs/0.0.1",
"user_agent": ["promptflow-sdk/0.0.1", f"Werkzeug/{werkzeug.__version__}", "local_pfs/0.0.1"],
}
for i, expected_activity in enumerate(expected_activities):
temp = default_expected_call.copy()
temp.update(expected_activity)
expected_activity = temp
for key, expected_value in expected_activity.items():
value = actual_activities[i][key]
if isinstance(expected_value, list):
value = list(sorted(value.split(" ")))
expected_value = list(sorted(expected_value))
assert (
value == expected_value
), f"{key} mismatch in {i+1}th call: expect {expected_value} but got {value}"
Expand All @@ -54,7 +57,12 @@ class PFSOperations:
def __init__(self, client: FlaskClient):
self._client = client

def remote_user_header(self):
def remote_user_header(self, user_agent=None):
if user_agent:
return {
"X-Remote-User": getpass.getuser(),
"User-Agent": user_agent,
}
return {"X-Remote-User": getpass.getuser()}

def heartbeat(self):
Expand All @@ -67,8 +75,10 @@ def connection_operation_with_invalid_user(self, status_code=None):
assert status_code == response.status_code, response.text
return response

def list_connections(self, status_code=None):
response = self._client.get(f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header())
def list_connections(self, status_code=None, user_agent=None):
response = self._client.get(
f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header(user_agent=user_agent)
)
if status_code:
assert status_code == response.status_code, response.text
return response
Expand Down

0 comments on commit 2657a08

Please sign in to comment.