Skip to content

Commit

Permalink
fix: reset user agent for each requests in local pfs
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Mar 11, 2024
1 parent f4215a4 commit 27ea2f4
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 28 deletions.
9 changes: 7 additions & 2 deletions src/promptflow/promptflow/_core/operation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,13 @@ def append_user_agent(self, user_agent: str):
user_agent (str): The user agent information to append.
"""
if OperationContext.USER_AGENT_KEY in self:
if user_agent not in self.user_agent:
self.user_agent = f"{self.user_agent.strip()} {user_agent.strip()}"
user_agent_seperator = " "
new_user_agents = user_agent.strip().split(user_agent_seperator)
current_user_agents = self.user_agent.strip().split(user_agent_seperator)
for new_user_agent in new_user_agents:
if new_user_agent not in current_user_agents:
current_user_agents.append(new_user_agent)
self.user_agent = user_agent_seperator.join(current_user_agents)
else:
self.user_agent = user_agent

Expand Down
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_sdk/_service/apis/line_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from promptflow._sdk._constants import PFS_MODEL_DATETIME_FORMAT, CumulativeTokenCountFieldName, LineRunFieldName
from promptflow._sdk._service import Namespace, Resource
from promptflow._sdk._service.utils.utils import get_client_from_request
from promptflow._sdk._service.utils.utils import get_client_based_on_pfs_request
from promptflow._sdk.entities._trace import LineRun

api = Namespace("LineRuns", description="Line runs management")
Expand Down Expand Up @@ -81,7 +81,7 @@ class LineRuns(Resource):
def get(self):
from promptflow import PFClient

client: PFClient = get_client_from_request()
client: PFClient = get_client_based_on_pfs_request()
args = ListLineRunParser.from_request()
line_runs: typing.List[LineRun] = client._traces.list_line_runs(
session_id=args.session_id,
Expand Down
32 changes: 18 additions & 14 deletions src/promptflow/promptflow/_sdk/_service/apis/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from promptflow._sdk._constants import FlowRunProperties, get_list_view_type
from promptflow._sdk._errors import RunNotFoundError
from promptflow._sdk._service import Namespace, Resource, fields
from promptflow._sdk._service.utils.utils import build_pfs_user_agent, get_client_from_request, make_response_no_content
from promptflow._sdk._service.utils.utils import (
build_pfs_user_agent,
get_client_based_on_pfs_request,
make_response_no_content,
)
from promptflow._sdk.entities import Run as RunEntity
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
from promptflow._utils.yaml_utils import dump_yaml
Expand Down Expand Up @@ -52,7 +56,7 @@ def get(self):
max_results = None
list_view_type = get_list_view_type(archived_only=archived_only, include_archived=include_archived)

runs = get_client_from_request().runs.list(max_results=max_results, list_view_type=list_view_type)
runs = get_client_based_on_pfs_request().runs.list(max_results=max_results, list_view_type=list_view_type)
runs_dict = [run._to_dict() for run in runs]
return jsonify(runs_dict)

Expand Down Expand Up @@ -88,7 +92,7 @@ def post(self):
stdout, _ = process.communicate()
if process.returncode == 0:
try:
run = get_client_from_request().runs._get(name=run_name)
run = get_client_based_on_pfs_request().runs._get(name=run_name)
return jsonify(run._to_dict())
except RunNotFoundError as e:
raise RunNotFoundError(
Expand All @@ -107,21 +111,21 @@ class Run(Resource):
def put(self, name: str):
args = update_run_parser.parse_args()
tags = json.loads(args.tags) if args.tags else None
run = get_client_from_request().runs.update(
run = get_client_based_on_pfs_request().runs.update(
name=name, display_name=args.display_name, description=args.description, tags=tags
)
return jsonify(run._to_dict())

@api.response(code=200, description="Get run info", model=dict_field)
@api.doc(description="Get run")
def get(self, name: str):
run = get_client_from_request().runs.get(name=name)
run = get_client_based_on_pfs_request().runs.get(name=name)
return jsonify(run._to_dict())

@api.response(code=204, description="Delete run", model=dict_field)
@api.doc(description="Delete run")
def delete(self, name: str):
get_client_from_request().runs.delete(name=name)
get_client_based_on_pfs_request().runs.delete(name=name)
return make_response_no_content()


Expand All @@ -130,7 +134,7 @@ class FlowChildRuns(Resource):
@api.response(code=200, description="Child runs", model=list_field)
@api.doc(description="Get child runs")
def get(self, name: str):
run = get_client_from_request().runs.get(name=name)
run = get_client_based_on_pfs_request().runs.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
detail_dict = local_storage_op.load_detail()
return jsonify(detail_dict["flow_runs"])
Expand All @@ -141,7 +145,7 @@ class FlowNodeRuns(Resource):
@api.response(code=200, description="Node runs", model=list_field)
@api.doc(description="Get node runs info")
def get(self, name: str, node_name: str):
run = get_client_from_request().runs.get(name=name)
run = get_client_based_on_pfs_request().runs.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
detail_dict = local_storage_op.load_detail()
node_runs = [item for item in detail_dict["node_runs"] if item["node"] == node_name]
Expand All @@ -153,7 +157,7 @@ class MetaData(Resource):
@api.doc(description="Get metadata of run")
@api.response(code=200, description="Run metadata", model=dict_field)
def get(self, name: str):
run = get_client_from_request().runs.get(name=name)
run = get_client_based_on_pfs_request().runs.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
metadata = RunMetadata(
name=run.name,
Expand All @@ -175,7 +179,7 @@ class LogContent(Resource):
@api.doc(description="Get run log content")
@api.response(code=200, description="Log content", model=fields.String)
def get(self, name: str):
run = get_client_from_request().runs.get(name=name)
run = get_client_based_on_pfs_request().runs.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
log_content = local_storage_op.logger.get_logs()
return make_response(log_content)
Expand All @@ -186,7 +190,7 @@ class Metrics(Resource):
@api.doc(description="Get run metrics")
@api.response(code=200, description="Run metrics", model=dict_field)
def get(self, name: str):
run = get_client_from_request().runs.get(name=name)
run = get_client_based_on_pfs_request().runs.get(name=name)
local_storage_op = LocalStorageOperations(run=run)
metrics = local_storage_op.load_metrics()
return jsonify(metrics)
Expand All @@ -201,7 +205,7 @@ def get(self, name: str):
with tempfile.TemporaryDirectory() as temp_dir:
from promptflow._sdk.operations import RunOperations

run_op: RunOperations = get_client_from_request().runs
run_op: RunOperations = get_client_based_on_pfs_request().runs
html_path = Path(temp_dir) / "visualize_run.html"
# visualize operation may accept name in string
run_op.visualize(name, html_path=html_path)
Expand All @@ -215,7 +219,7 @@ class ArchiveRun(Resource):
@api.doc(description="Archive run")
@api.response(code=200, description="Archived run", model=dict_field)
def get(self, name: str):
run = get_client_from_request().runs.archive(name=name)
run = get_client_based_on_pfs_request().runs.archive(name=name)
return jsonify(run._to_dict())


Expand All @@ -224,5 +228,5 @@ class RestoreRun(Resource):
@api.doc(description="Restore run")
@api.response(code=200, description="Restored run", model=dict_field)
def get(self, name: str):
run = get_client_from_request().runs.restore(name=name)
run = get_client_based_on_pfs_request().runs.restore(name=name)
return jsonify(run._to_dict())
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/_sdk/_service/apis/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from promptflow._sdk._constants import PFS_MODEL_DATETIME_FORMAT
from promptflow._sdk._service import Namespace, Resource
from promptflow._sdk._service.utils.utils import get_client_from_request
from promptflow._sdk._service.utils.utils import get_client_based_on_pfs_request

api = Namespace("Spans", description="Spans Management")

Expand Down Expand Up @@ -108,7 +108,7 @@ class Spans(Resource):
def get(self):
from promptflow import PFClient

client: PFClient = get_client_from_request()
client: PFClient = get_client_based_on_pfs_request()
args = ListSpanParser.from_request()
spans = client._traces.list_spans(
session_id=args.session_id,
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_sdk/_service/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def validate_port(port, force_start):
port = get_port_from_config(create_if_not_exists=True)
validate_port(port, args.force)

if sys.executable.endswith("pfcli.exe"):
if sys.executable.endswith("pfcli.exe") or args.debug:
# For msi installer, use sdk api to start pfs since it's not supported to invoke waitress by cli directly
# after packaged by Pyinstaller.
app, _ = create_app()
Expand Down
17 changes: 14 additions & 3 deletions src/promptflow/promptflow/_sdk/_service/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
PF_SERVICE_PORT_FILE,
)
from promptflow._sdk._errors import ConnectionNotFoundError, RunNotFoundError
from promptflow._sdk._utils import get_promptflow_sdk_version, read_write_by_user
from promptflow._sdk._utils import ClientUserAgentUtil, get_promptflow_sdk_version, read_write_by_user
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow._utils.yaml_utils import dump_yaml, load_yaml
from promptflow._version import VERSION
Expand Down Expand Up @@ -231,13 +231,24 @@ def __post_init__(self, exception, status_code):


def build_pfs_user_agent():
# For local pfs:
# 1. the operation context won't be cleared after each request.
# 2. user agent can be different for each request.
# So, we need to reset user agent before each request.
# TODO: this may impact async requests as operation context may be changed by another request during one request.
# 1) the best practice should be passing user agent in kwargs of operations.
# 2) given local pfs is for authoring and user don't have multiple agents for now, we can wait for feedback.
ClientUserAgentUtil.pop_current_user_agent()

extra_agent = f"local_pfs/{VERSION}"
if request.user_agent.string:
return f"{request.user_agent.string} {extra_agent}"
return extra_agent


def get_client_from_request() -> "PFClient":
def get_client_based_on_pfs_request() -> "PFClient":
from promptflow._sdk._pf_client import PFClient

return PFClient(user_agent=build_pfs_user_agent())
pf_client = PFClient(user_agent=build_pfs_user_agent())
_ = ClientUserAgentUtil.get_user_agent()
return pf_client
7 changes: 7 additions & 0 deletions src/promptflow/promptflow/_sdk/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,13 @@ def get_user_agent(cls):
# directly get from context since client side won't need promptflow/xxx.
return context.get(OperationContext.USER_AGENT_KEY, "").strip()

@classmethod
def pop_current_user_agent(cls):
from promptflow._core.operation_context import OperationContext

context = cls._get_context()
return context.pop(OperationContext.USER_AGENT_KEY, None)

@classmethod
def append_user_agent(cls, user_agent: Optional[str]):
if not user_agent:
Expand Down
16 changes: 16 additions & 0 deletions src/promptflow/tests/sdk_pfs_test/e2etests/test_connection_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ def test_list_connections(self, pf_client: PFClient, pfs_op: PFSOperations) -> N

assert len(connections) >= 1

def test_list_connections_with_different_user_agent(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
create_custom_connection(pf_client)
base_user_agent = ["promptflow-sdk/0.0.1", "local_pfs/0.0.1"]
for _, extra_user_agent in enumerate(
[
["another_test_user_agent/0.0.1"],
["test_user_agent/0.0.1"],
["another_test_user_agent/0.0.1"],
["test_user_agent/0.0.1"],
]
):
with check_activity_end_telemetry(
activity_name="pf.connections.list", user_agent=base_user_agent + extra_user_agent
):
pfs_op.list_connections(user_agent=extra_user_agent)

def test_get_connection(self, pf_client: PFClient, pfs_op: PFSOperations) -> None:
name = create_custom_connection(pf_client)
with check_activity_end_telemetry(activity_name="pf.connections.get"):
Expand Down
18 changes: 14 additions & 4 deletions src/promptflow/tests/sdk_pfs_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ def check_activity_end_telemetry(
"first_call": True,
"activity_type": "PublicApi",
"completion_status": "Success",
"user_agent": f"promptflow-sdk/0.0.1 Werkzeug/{werkzeug.__version__} local_pfs/0.0.1",
"user_agent": ["promptflow-sdk/0.0.1", f"Werkzeug/{werkzeug.__version__}", "local_pfs/0.0.1"],
}
for i, expected_activity in enumerate(expected_activities):
temp = default_expected_call.copy()
temp.update(expected_activity)
expected_activity = temp
for key, expected_value in expected_activity.items():
value = actual_activities[i][key]
if isinstance(expected_value, list):
value = list(sorted(value.split(" ")))
expected_value = list(sorted(expected_value))
assert (
value == expected_value
), f"{key} mismatch in {i+1}th call: expect {expected_value} but got {value}"
Expand All @@ -54,7 +57,12 @@ class PFSOperations:
def __init__(self, client: FlaskClient):
self._client = client

def remote_user_header(self):
def remote_user_header(self, user_agent=None):
if user_agent:
return {
"X-Remote-User": getpass.getuser(),
"User-Agent": user_agent,
}
return {"X-Remote-User": getpass.getuser()}

def heartbeat(self):
Expand All @@ -67,8 +75,10 @@ def connection_operation_with_invalid_user(self, status_code=None):
assert status_code == response.status_code, response.text
return response

def list_connections(self, status_code=None):
response = self._client.get(f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header())
def list_connections(self, status_code=None, user_agent=None):
response = self._client.get(
f"{self.CONNECTION_URL_PREFIX}/", headers=self.remote_user_header(user_agent=user_agent)
)
if status_code:
assert status_code == response.status_code, response.text
return response
Expand Down

0 comments on commit 27ea2f4

Please sign in to comment.