Skip to content

Commit

Permalink
Change to also retrieve created by info during handling request, sinc…
Browse files Browse the repository at this point in the history
…e customer may not run `az login`
  • Loading branch information
robbenwang committed Mar 8, 2024
1 parent fdb286d commit 626ad5b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
11 changes: 8 additions & 3 deletions src/promptflow/promptflow/_sdk/_service/apis/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
start_time = datetime.now()
from promptflow._sdk._service.app import CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE
from promptflow._sdk._service.app import retrieve_created_by_info_with_cache
from promptflow.azure._storage.cosmosdb.client import get_client
from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
from promptflow.azure._storage.cosmosdb.summary import Summary
Expand All @@ -92,17 +92,22 @@ def _try_write_trace_to_cosmosdb(all_spans, logger: logging.Logger):

get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name)

# Retrieve created_by info, we already get it in advance as starting the service.
# But user may not run `az login` before running pfs service.
# So we need to call this function again to get the created_by info.
created_by = retrieve_created_by_info_with_cache()

span_thread.join()

for span in all_spans:
span_client = get_client(CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name)
result = SpanCosmosDB(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE).persist(span_client)
result = SpanCosmosDB(span, created_by).persist(span_client)
# None means the span already exists, then we don't need to persist the summary also.
if result is not None:
line_summary_client = get_client(
CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name
)
Summary(span, CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE, logger).persist(line_summary_client)
Summary(span, created_by, logger).persist(line_summary_client)
logger.info(
(
f"Finish writing trace to cosmosdb, total spans count: {len(all_spans)}."
Expand Down
75 changes: 45 additions & 30 deletions src/promptflow/promptflow/_sdk/_service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ---------------------------------------------------------
import logging
import sys
import threading
import time
from datetime import datetime, timedelta
from logging.handlers import RotatingFileHandler
Expand All @@ -11,6 +12,7 @@
from flask_cors import CORS
from werkzeug.exceptions import HTTPException

from promptflow._sdk._configuration import Configuration
from promptflow._sdk._constants import (
HOME_PROMPT_FLOW_DIR,
PF_SERVICE_HOUR_TIMEOUT,
Expand All @@ -31,7 +33,12 @@
get_port_from_config,
kill_exist_service,
)
from promptflow._sdk._utils import get_promptflow_sdk_version, overwrite_null_std_logger, read_write_by_user
from promptflow._sdk._utils import (
extract_workspace_triad_from_trace_provider,
get_promptflow_sdk_version,
overwrite_null_std_logger,
read_write_by_user,
)
from promptflow._utils.thread_utils import ThreadWithContextVars

overwrite_null_std_logger()
Expand All @@ -42,7 +49,41 @@ def heartbeat():
return jsonify(response)


CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE = {}
created_by_for_local_to_cloud_trace = {}
created_by_for_local_to_cloud_trace_lock = threading.Lock()


def retrieve_created_by_info_with_cache():
if len(created_by_for_local_to_cloud_trace) > 0:
return created_by_for_local_to_cloud_trace
trace_provider = Configuration.get_instance().get_trace_provider()
if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None:
return
with created_by_for_local_to_cloud_trace_lock:
if len(created_by_for_local_to_cloud_trace) > 0:
return created_by_for_local_to_cloud_trace
try:
# The total time of collecting info is about 3s.
# We may need to run below code more than once
# because user may not run `az login` before running pfs service.
import jwt
from azure.identity import DefaultAzureCredential

from promptflow.azure._utils.general import get_arm_token

default_credential = DefaultAzureCredential()

token = get_arm_token(credential=default_credential)
decoded_token = jwt.decode(token, options={"verify_signature": False})
created_by_for_local_to_cloud_trace.update(
{
"object_id": decoded_token["oid"],
"tenant_id": decoded_token["tid"],
"name": decoded_token.get("name", "unknown"),
}
)
except Exception as e:
current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")


def create_app():
Expand Down Expand Up @@ -86,33 +127,6 @@ def create_app():
# Set app logger to the only one RotatingFileHandler to avoid duplicate logs
app.logger.handlers = [handler]

def initialize_created_by_info():
from promptflow._sdk._configuration import Configuration
from promptflow._sdk._utils import extract_workspace_triad_from_trace_provider

trace_provider = Configuration.get_instance().get_trace_provider()
if trace_provider is None or extract_workspace_triad_from_trace_provider(trace_provider) is None:
return
try:
import jwt
from azure.identity import DefaultAzureCredential

from promptflow.azure._utils.general import get_arm_token

default_credential = DefaultAzureCredential()

token = get_arm_token(credential=default_credential)
decoded_token = jwt.decode(token, options={"verify_signature": False})
user_object_id, user_tenant_id = decoded_token["oid"], decoded_token["tid"]
CREATED_BY_FOR_LOCAL_TO_CLOUD_TRACE.update(
{
"object_id": user_object_id,
"tenant_id": user_tenant_id,
}
)
except Exception as e:
current_app.logger.error(f"Failed to get created_by info, ignore it: {e}")

# Basic error handler
@api.errorhandler(Exception)
def handle_exception(e):
Expand Down Expand Up @@ -167,7 +181,8 @@ def monitor_request():
kill_exist_service(port)
break

initialize_created_by_info()
# Retrieve created_by info and cache it in advance to avoid blocking the first request.
retrieve_created_by_info_with_cache()
if not sys.executable.endswith("pfcli.exe"):
monitor_thread = ThreadWithContextVars(target=monitor_request, daemon=True)
monitor_thread.start()
Expand Down

0 comments on commit 626ad5b

Please sign in to comment.