Skip to content

Commit

Permalink
Improve created_by info for local to cloud trace (#2232)
Browse files Browse the repository at this point in the history
# Description

1. Add name to the created_by info (for display in the UI)
2. Call the method to get created_by info on each request instead of at
service startup
Since this info is required by the local-to-cloud trace, it is clearer not to
depend on the Configuration file.


![image](https://github.com/microsoft/promptflow/assets/17527303/d8957e46-9848-40c3-8f4b-8d1fb1fc3816)


# All Promptflow Contribution checklist:
- [X] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [X] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [X] Title of the pull request is clear and informative.
- [X] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: robbenwang <[email protected]>
  • Loading branch information
huaiyan and robbenwang authored Mar 12, 2024
1 parent bf96c0b commit 840fa04
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 41 deletions.
35 changes: 26 additions & 9 deletions src/promptflow/promptflow/_sdk/_service/apis/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import traceback
from datetime import datetime
from typing import Callable

from flask import request
from google.protobuf.json_format import MessageToJson
Expand All @@ -28,7 +29,15 @@
from promptflow._utils.thread_utils import ThreadWithContextVars


def trace_collector(logger: logging.Logger):
def trace_collector(get_created_by_info_with_cache: Callable, logger: logging.Logger):
"""
This function is target to be reused in other places, so pass in get_created_by_info_with_cache and logger to avoid
app related dependencies.
Args:
get_created_by_info_with_cache (Callable): A function that retrieves information about the creator of the trace.
logger (logging.Logger): The logger object used for logging.
"""
content_type = request.headers.get("Content-Type")
# binary protobuf encoding
if "application/x-protobuf" in content_type:
Expand All @@ -55,15 +64,17 @@ def trace_collector(logger: logging.Logger):
all_spans.append(span)

# Create a new thread to write trace to cosmosdb to avoid blocking the main thread
ThreadWithContextVars(target=_try_write_trace_to_cosmosdb, args=(all_spans, logger)).start()
ThreadWithContextVars(
target=_try_write_trace_to_cosmosdb, args=(all_spans, get_created_by_info_with_cache, logger)
).start()
return "Traces received", 200

# JSON protobuf encoding
elif "application/json" in content_type:
raise NotImplementedError


def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):
def _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache: Callable, logger: logging.Logger):
if not all_spans:
return
try:
Expand All @@ -78,31 +89,37 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
start_time = datetime.now()
from promptflow._sdk._service.app import CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE
from promptflow.azure._storage.cosmosdb.client import get_client
from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
from promptflow.azure._storage.cosmosdb.summary import Summary

# Load span and summary clients first time may slow.
# So, we load 2 client in parallel for warm up.
span_thread = ThreadWithContextVars(
span_client_thread = ThreadWithContextVars(
target=get_client, args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
)
span_thread.start()
span_client_thread.start()

# Load created_by info first time may slow. So, we load it in parallel for warm up.
created_by_thread = ThreadWithContextVars(target=get_created_by_info_with_cache)
created_by_thread.start()

get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name)

span_thread.join()
span_client_thread.join()
created_by_thread.join()

created_by = get_created_by_info_with_cache()

for span in all_spans:
span_client = get_client(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
result = SpanCosmosDB(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE).persist(span_client)
result = SpanCosmosDB(span, created_by).persist(span_client)
# None means the span already exists, then we don't need to persist the summary also.
if result is not None:
line_summary_client = get_client(
CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name
)
Summary(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE, logger).persist(line_summary_client)
Summary(span, created_by, logger).persist(line_summary_client)
logger.info(
(
f"Finish writing trace to cosmosdb, total spans count: {len(all_spans)}."
Expand Down
71 changes: 39 additions & 32 deletions src/promptflow/promptflow/_sdk/_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ---------------------------------------------------------
import logging
import sys
import threading
import time
from datetime import datetime, timedelta
from logging.handlers import RotatingFileHandler
Expand Down Expand Up @@ -42,9 +43,6 @@ def heartbeat():
return jsonify(response)


CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE = {}


def create_app():
app = Flask(__name__)

Expand All @@ -54,7 +52,9 @@ def create_app():
CORS(app)

app.add_url_rule("/heartbeat", view_func=heartbeat)
app.add_url_rule("/v1/traces", view_func=lambda: trace_collector(app.logger), methods=["POST"])
app.add_url_rule(
"/v1/traces", view_func=lambda: trace_collector(get_created_by_info_with_cache, app.logger), methods=["POST"]
)
with app.app_context():
api_v1 = Blueprint("Prompt Flow Service", __name__, url_prefix="/v1.0")

Expand Down Expand Up @@ -86,33 +86,6 @@ def create_app():
# Set app logger to the only one RotatingFileHandler to avoid duplicate logs
app.logger.handlers = [handler]

def initialize_created_by_info():
from promptflow._sdk._configuration import Configuration
from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider

trace_provider = Configuration.get_instance().get_trace_provider()
if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None:
return
try:
import jwt
from azure.identity import DefaultAzureCredential

from promptflow.azure._utils.general import get_arm_token

default_credential = DefaultAzureCredential()

token = get_arm_token(credential=default_credential)
decoded_token = jwt.decode(token, options={"verify_signature": False})
user_object_id, user_tenant_id = decoded_token["oid"], decoded_token["tid"]
CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE.update(
{
"object_id": user_object_id,
"tenant_id": user_tenant_id,
}
)
except Exception as e:
current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")

# Basic error handler
@api.errorhandler(Exception)
def handle_exception(e):
Expand Down Expand Up @@ -167,8 +140,42 @@ def monitor_request():
kill_exist_service(port)
break

initialize_created_by_info()
if not sys.executable.endswith("pfcli.exe"):
monitor_thread = ThreadWithContextVars(target=monitor_request, daemon=True)
monitor_thread.start()
return app, api


# Module-level cache for the created_by info, populated lazily on first use.
created_by_for_local_to_cloud_trace = {}
# Guards the one-time population of the cache above.
created_by_for_local_to_cloud_trace_lock = threading.Lock()


def get_created_by_info_with_cache():
    """Return the created_by info for local-to-cloud trace, fetching it on first call.

    The info is a dict with ``object_id``, ``tenant_id`` and ``name`` decoded from
    an ARM token. The first successful call populates the module-level cache;
    subsequent calls return the cached dict without re-fetching. Double-checked
    locking ensures only one thread performs the slow fetch (~3s).

    Returns:
        dict: The created_by info, or an empty dict if fetching failed.
    """
    # Fast path: cache already populated, no lock needed.
    if created_by_for_local_to_cloud_trace:
        return created_by_for_local_to_cloud_trace
    with created_by_for_local_to_cloud_trace_lock:
        # Re-check inside the lock: another thread may have filled the cache
        # while we were waiting to acquire it.
        if created_by_for_local_to_cloud_trace:
            return created_by_for_local_to_cloud_trace
        try:
            # The total time of collecting info is about 3s.
            import jwt
            from azure.identity import DefaultAzureCredential

            from promptflow.azure._utils.general import get_arm_token

            default_credential = DefaultAzureCredential()

            token = get_arm_token(credential=default_credential)
            # Signature verification is skipped: only the claims are needed, and
            # the token was just issued to this process by the credential.
            decoded_token = jwt.decode(token, options={"verify_signature": False})
            created_by_for_local_to_cloud_trace.update(
                {
                    "object_id": decoded_token["oid"],
                    "tenant_id": decoded_token["tid"],
                    # Use appid as fallback for service principal scenario.
                    "name": decoded_token.get("name", decoded_token.get("appid", "")),
                }
            )
        except Exception as e:
            # This function is only target to be used in Flask app.
            current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")
    return created_by_for_local_to_cloud_trace

0 comments on commit 840fa04

Please sign in to comment.