Skip to content

Commit

Permalink
Fix comment, pass in callable to get created_by info for code reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
robbenwang committed Mar 11, 2024
1 parent 4701b91 commit cf9f6f8
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 47 deletions.
21 changes: 13 additions & 8 deletions src/promptflow/promptflow/_sdk/_service/apis/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import traceback
from datetime import datetime
from typing import Callable

from flask import request
from google.protobuf.json_format import MessageToJson
Expand All @@ -28,7 +29,9 @@
from promptflow._utils.thread_utils import ThreadWithContextVars


def trace_collector(logger: logging.Logger):
# Take get_created_by_info_with_cache and logger as arguments instead of importing them from the app,
# so that this function has no app-related dependency and can be reused in other places.
def trace_collector(get_created_by_info_with_cache: Callable, logger: logging.Logger):
content_type = request.headers.get("Content-Type")
# binary protobuf encoding
if "application/x-protobuf" in content_type:
Expand All @@ -55,15 +58,17 @@ def trace_collector(logger: logging.Logger):
all_spans.append(span)

# Create a new thread to write trace to cosmosdb to avoid blocking the main thread
ThreadWithContextVars(target=_try_write_trace_to_cosmosdb, args=(all_spans, logger)).start()
ThreadWithContextVars(
target=_try_write_trace_to_cosmosdb, args=(all_spans, get_created_by_info_with_cache, logger)
).start()
return "Traces received", 200

# JSON protobuf encoding
elif "application/json" in content_type:
raise NotImplementedError


def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):
def _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache: Callable, logger: logging.Logger):
if not all_spans:
return
try:
Expand All @@ -78,7 +83,6 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
start_time = datetime.now()
from promptflow._sdk._service.app import retrieve_created_by_info_with_cache
from promptflow.azure._storage.cosmosdb.client import get_client
from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
from promptflow.azure._storage.cosmosdb.summary import Summary
Expand All @@ -92,10 +96,11 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name)

# Retrieve created_by info. We already try to get it in advance when starting the service,
# but the user may not have run `az login` before starting the pfs service,
# so we need to call this function again to get the created_by info.
created_by = retrieve_created_by_info_with_cache()
# In the local scenario, we already try to get created_by info in advance when starting the service.
# But if the customer didn't run `az login` beforehand, we can't get it at that point.
# So we try to get and cache it again after getting the CosmosDB token.
# Don't bother running this in a new thread because, in most cases, it is already cached.
created_by = get_created_by_info_with_cache() if get_created_by_info_with_cache else {}

span_thread.join()

Expand Down
82 changes: 43 additions & 39 deletions src/promptflow/promptflow/_sdk/_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,43 +49,6 @@ def heartbeat():
return jsonify(response)


# Process-wide cache of the created_by info, filled at most once by
# retrieve_created_by_info_with_cache(); the lock serializes the retrieval.
created_by_for_local_to_cloud_trace = {}
created_by_for_local_to_cloud_trace_lock = threading.Lock()


def retrieve_created_by_info_with_cache():
    """Return the cached ``created_by`` info, retrieving it on first use.

    On a cache miss, a token is obtained via ``get_arm_token`` with a
    ``DefaultAzureCredential`` and decoded without signature verification; its
    ``oid``, ``tid`` and ``name`` claims are stored in the module-level
    ``created_by_for_local_to_cloud_trace`` dict as ``object_id``,
    ``tenant_id`` and ``name`` (``"unknown"`` when the claim is absent).

    Returns:
        The shared cache dict once it is populated. ``None`` when no trace
        provider is configured, when
        ``extract_workspace_triad_from_trace_provider`` returns ``None`` for it,
        or when retrieval fails (the error is logged, not raised).
    """
    # Fast path: once the cache is filled there is no need to take the lock.
    if created_by_for_local_to_cloud_trace:
        return created_by_for_local_to_cloud_trace
    trace_provider = Configuration.get_instance().get_trace_provider()
    if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None:
        return None
    with created_by_for_local_to_cloud_trace_lock:
        # Re-check under the lock: another thread may have filled the cache
        # while this one was waiting.
        if created_by_for_local_to_cloud_trace:
            return created_by_for_local_to_cloud_trace
        try:
            # The total time of collecting info is about 3s.
            # We may need to run the code below more than once
            # because the user may not run `az login` before running the pfs service.
            import jwt
            from azure.identity import DefaultAzureCredential

            from promptflow.azure._utils.general import get_arm_token

            default_credential = DefaultAzureCredential()

            token = get_arm_token(credential=default_credential)
            decoded_token = jwt.decode(token, options={"verify_signature": False})
            created_by_for_local_to_cloud_trace.update(
                {
                    "object_id": decoded_token["oid"],
                    "tenant_id": decoded_token["tid"],
                    "name": decoded_token.get("name", "unknown"),
                }
            )
            # Hand back the freshly populated cache; without this the first
            # successful call would fall through and return None.
            return created_by_for_local_to_cloud_trace
        except Exception as e:
            # Best effort: a failure here is logged and ignored rather than raised.
            current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")
    return None


def create_app():
app = Flask(__name__)

Expand All @@ -95,7 +58,9 @@ def create_app():
CORS(app)

app.add_url_rule("/heartbeat", view_func=heartbeat)
app.add_url_rule("/v1/traces", view_func=lambda: trace_collector(app.logger), methods=["POST"])
app.add_url_rule(
"/v1/traces", view_func=lambda: trace_collector(get_created_by_info_with_cache, app.logger), methods=["POST"]
)
with app.app_context():
api_v1 = Blueprint("Prompt Flow Service", __name__, url_prefix="/v1.0")

Expand Down Expand Up @@ -182,8 +147,47 @@ def monitor_request():
break

# Retrieve created_by info and cache it in advance to avoid blocking the first request.
retrieve_created_by_info_with_cache()
get_created_by_info_with_cache()
if not sys.executable.endswith("pfcli.exe"):
monitor_thread = ThreadWithContextVars(target=monitor_request, daemon=True)
monitor_thread.start()
return app, api


created_by_for_local_to_cloud_trace = {}
created_by_for_local_to_cloud_trace_lock = threading.Lock()


def get_created_by_info_with_cache():
    """Look up the ``created_by`` info, filling the module-level cache on a miss.

    A populated ``created_by_for_local_to_cloud_trace`` dict is returned as is.
    Otherwise, provided a trace provider is configured and
    ``extract_workspace_triad_from_trace_provider`` does not return ``None``
    for it, a token from ``get_arm_token`` is decoded (signature verification
    disabled) and its ``oid`` / ``tid`` / ``name`` claims are cached as
    ``object_id`` / ``tenant_id`` / ``name``.

    Returns:
        The shared cache dict, or ``None`` when the provider check fails or the
        retrieval raises (the error is logged, not propagated).
    """
    if created_by_for_local_to_cloud_trace:
        return created_by_for_local_to_cloud_trace

    provider = Configuration.get_instance().get_trace_provider()
    if provider is None:
        return None
    if extract_workspace_triad_from_trace_provider(provider) is None:
        return None

    with created_by_for_local_to_cloud_trace_lock:
        # Check again now that the lock is held; a concurrent caller may have
        # populated the cache in the meantime.
        if created_by_for_local_to_cloud_trace:
            return created_by_for_local_to_cloud_trace
        try:
            # Collecting the info takes about 3s in total.
            # This path may run more than once, because the user may not have
            # run `az login` before starting the pfs service.
            import jwt
            from azure.identity import DefaultAzureCredential

            from promptflow.azure._utils.general import get_arm_token

            arm_token = get_arm_token(credential=DefaultAzureCredential())
            claims = jwt.decode(arm_token, options={"verify_signature": False})
            created_by_for_local_to_cloud_trace.update(
                object_id=claims["oid"],
                tenant_id=claims["tid"],
                name=claims.get("name", "unknown"),
            )
        except Exception as e:
            # This function is only intended to be used within the Flask app.
            current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")
            return None
        else:
            return created_by_for_local_to_cloud_trace

0 comments on commit cf9f6f8

Please sign in to comment.