diff --git a/src/promptflow/promptflow/_sdk/_service/apis/collector.py b/src/promptflow/promptflow/_sdk/_service/apis/collector.py index 0a31761074b..647f8a64a11 100644 --- a/src/promptflow/promptflow/_sdk/_service/apis/collector.py +++ b/src/promptflow/promptflow/_sdk/_service/apis/collector.py @@ -11,6 +11,7 @@ import logging import traceback from datetime import datetime +from typing import Callable from flask import request from google.protobuf.json_format import MessageToJson @@ -28,7 +29,15 @@ from promptflow._utils.thread_utils import ThreadWithContextVars -def trace_collector(logger: logging.Logger): +def trace_collector(get_created_by_info_with_cache: Callable, logger: logging.Logger): + """ + This function is targeted to be reused in other places, so pass in get_created_by_info_with_cache and logger to + avoid app-related dependencies. + + Args: + get_created_by_info_with_cache (Callable): A function that retrieves information about the creator of the trace. + logger (logging.Logger): The logger object used for logging.
+ """ content_type = request.headers.get("Content-Type") # binary protobuf encoding if "application/x-protobuf" in content_type: @@ -55,7 +64,9 @@ def trace_collector(logger: logging.Logger): all_spans.append(span) # Create a new thread to write trace to cosmosdb to avoid blocking the main thread - ThreadWithContextVars(target=_try_write_trace_to_cosmosdb, args=(all_spans, logger)).start() + ThreadWithContextVars( + target=_try_write_trace_to_cosmosdb, args=(all_spans, get_created_by_info_with_cache, logger) + ).start() return "Traces received", 200 # JSON protobuf encoding @@ -63,7 +74,7 @@ def trace_collector(logger: logging.Logger): raise NotImplementedError -def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger): +def _try_write_trace_to_cosmosdb(all_spans, get_created_by_info_with_cache: Callable, logger: logging.Logger): if not all_spans: return try: @@ -78,31 +89,37 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger): logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.") start_time = datetime.now() - from promptflow._sdk._service.app import CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE from promptflow.azure._storage.cosmosdb.client import get_client from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB from promptflow.azure._storage.cosmosdb.summary import Summary # Load span and summary clients first time may slow. # So, we load 2 client in parallel for warm up. - span_thread = ThreadWithContextVars( + span_client_thread = ThreadWithContextVars( target=get_client, args=(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name) ) - span_thread.start() + span_client_thread.start() + + # Load created_by info first time may slow. So, we load it in parallel for warm up. 
+ created_by_thread = ThreadWithContextVars(target=get_created_by_info_with_cache) + created_by_thread.start() get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name) - span_thread.join() + span_client_thread.join() + created_by_thread.join() + + created_by = get_created_by_info_with_cache() for span in all_spans: span_client = get_client(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name) - result = SpanCosmosDB(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE).persist(span_client) + result = SpanCosmosDB(span, created_by).persist(span_client) # None means the span already exists, then we don't need to persist the summary also. if result is not None: line_summary_client = get_client( CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name ) - Summary(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE, logger).persist(line_summary_client) + Summary(span, created_by, logger).persist(line_summary_client) logger.info( ( f"Finish writing trace to cosmosdb, total spans count: {len(all_spans)}." 
diff --git a/src/promptflow/promptflow/_sdk/_service/app.py b/src/promptflow/promptflow/_sdk/_service/app.py index 30993d0da05..d2555465b96 100644 --- a/src/promptflow/promptflow/_sdk/_service/app.py +++ b/src/promptflow/promptflow/_sdk/_service/app.py @@ -3,6 +3,7 @@ # --------------------------------------------------------- import logging import sys +import threading import time from datetime import datetime, timedelta from logging.handlers import RotatingFileHandler @@ -42,9 +43,6 @@ def heartbeat(): return jsonify(response) -CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE = {} - - def create_app(): app = Flask(__name__) @@ -54,7 +52,9 @@ def create_app(): CORS(app) app.add_url_rule("/heartbeat", view_func=heartbeat) - app.add_url_rule("/v1/traces", view_func=lambda: trace_collector(app.logger), methods=["POST"]) + app.add_url_rule( + "/v1/traces", view_func=lambda: trace_collector(get_created_by_info_with_cache, app.logger), methods=["POST"] + ) with app.app_context(): api_v1 = Blueprint("Prompt Flow Service", __name__, url_prefix="/v1.0") @@ -86,33 +86,6 @@ def create_app(): # Set app logger to the only one RotatingFileHandler to avoid duplicate logs app.logger.handlers = [handler] - def initialize_created_by_info(): - from promptflow._sdk._configuration import Configuration - from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider - - trace_provider = Configuration.get_instance().get_trace_provider() - if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None: - return - try: - import jwt - from azure.identity import DefaultAzureCredential - - from promptflow.azure._utils.general import get_arm_token - - default_credential = DefaultAzureCredential() - - token = get_arm_token(credential=default_credential) - decoded_token = jwt.decode(token, options={"verify_signature": False}) - user_object_id, user_tenant_id = decoded_token["oid"], decoded_token["tid"] - CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE.update( - { - 
"object_id": user_object_id, - "tenant_id": user_tenant_id, - } - ) - except Exception as e: - current_app.logger.error(f"Failed to get created_by info, ignore it: {e}") - # Basic error handler @api.errorhandler(Exception) def handle_exception(e): @@ -167,8 +140,42 @@ def monitor_request(): kill_exist_service(port) break - initialize_created_by_info() if not sys.executable.endswith("pfcli.exe"): monitor_thread = ThreadWithContextVars(target=monitor_request, daemon=True) monitor_thread.start() return app, api + + +created_by_for_local_to_cloud_trace = {} +created_by_for_local_to_cloud_trace_lock = threading.Lock() + + +def get_created_by_info_with_cache(): + if len(created_by_for_local_to_cloud_trace) > 0: + return created_by_for_local_to_cloud_trace + with created_by_for_local_to_cloud_trace_lock: + if len(created_by_for_local_to_cloud_trace) > 0: + return created_by_for_local_to_cloud_trace + try: + # The total time of collecting info is about 3s. + import jwt + from azure.identity import DefaultAzureCredential + + from promptflow.azure._utils.general import get_arm_token + + default_credential = DefaultAzureCredential() + + token = get_arm_token(credential=default_credential) + decoded_token = jwt.decode(token, options={"verify_signature": False}) + created_by_for_local_to_cloud_trace.update( + { + "object_id": decoded_token["oid"], + "tenant_id": decoded_token["tid"], + # Use appid as fallback for service principal scenario. + "name": decoded_token.get("name", decoded_token.get("appid", "")), + } + ) + except Exception as e: + # This function is only target to be used in Flask app. + current_app.logger.error(f"Failed to get created_by info, ignore it: {e}") + return created_by_for_local_to_cloud_trace