[Internal] Upload the events in span to Blob. (#2479)
# Description

In this PR, we upload span events to Blob Storage.
- Initialize the BlobClient.
- The azureml URI has the pattern
"azureml://subscriptions/$sub/resourcegroups/$rg/workspaces/$ws/datastores/$storename/paths/.promptflow/.trace/$collection_id/$trace_id/$span_id/$id"
(see the sketch after this list).
- We cannot get the datastore name from the blob client; it only knows the
container name. So we build the base URI, including the datastore name, in
the collector and pass it to the span client.
  - Getting the default datastore may take 2~5s, so we add a cache.
- Upload each event to blob and append its azureml URI to
"external_event_data_uris".
- Clear each event's "attributes" to reduce the span size.
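
For illustration, a minimal sketch of how that URI composes; all identifiers below are placeholders, not values from this PR:

```python
# Placeholder identifiers; the real values come from the workspace's default
# datastore and from the span being persisted.
EVENT_PATH_PREFIX = ".promptflow/.trace"


def event_path(collection_id: str, trace_id: str, span_id: str, idx: int) -> str:
    # Relative blob path: .promptflow/.trace/$collection_id/$trace_id/$span_id/$id
    return f"{EVENT_PATH_PREFIX}/{collection_id}/{trace_id}/{span_id}/{idx}"


blob_base_uri = (
    "azureml://subscriptions/<sub>/resourcegroups/<rg>/workspaces/<ws>"
    "/datastores/workspaceblobstore/paths/"
)

# The URI recorded in external_event_data_uris for event 0 of a span:
uri = blob_base_uri + event_path("<collection_id>", "<trace_id>", "<span_id>", 0)
```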

Test case: [query
link](https://int.ml.azure.com/prompts/trace/list?wsid=%2Fsubscriptions%2F96aede12-2f73-41cb-b983-6d11a904839b%2FresourceGroups%2Fpromptflow%2Fproviders%2FMicrosoft.MachineLearningServices%2Fworkspaces%2Fpromptflow-eastus-dev&searchText=%7B%22batchRunId%22%3A%22chat_with_pdf_variant_0_20240328_204721_774665%22%7D&tid=72f988bf-86f1-41af-91ab-2d7cd011db47)


![image](https://github.com/microsoft/promptflow/assets/2418764/8dae02b7-a793-44b0-b31b-d25cee8b4bec)

Reminder: To upload data to blob, the user needs the "Storage Blob Data
Contributor" role on the workspace's storage account.

---------

Co-authored-by: Yangtong Xu <[email protected]>
riddlexu and Yangtong Xu authored Mar 29, 2024
1 parent f3edcb8 commit 0975987
Showing 8 changed files with 243 additions and 22 deletions.
@@ -0,0 +1,3 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
113 changes: 113 additions & 0 deletions src/promptflow-azure/promptflow/azure/_storage/blob/client.py
@@ -0,0 +1,113 @@
import datetime
import logging
import threading
import traceback
from typing import Optional, Tuple

from azure.ai.ml import MLClient
from azure.ai.ml._azure_environments import _get_storage_endpoint_from_metadata
from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
from azure.ai.ml.constants._common import LONG_URI_FORMAT, STORAGE_ACCOUNT_URLS
from azure.ai.ml.entities._datastore.datastore import Datastore
from azure.storage.blob import ContainerClient

from promptflow.exceptions import UserErrorException

_datastore_cache = {}
_thread_lock = threading.Lock()
_cache_timeout = 60 * 4  # Align the cache TTL with the cosmosdb client.


def get_datastore_container_client(
    logger: logging.Logger,
    subscription_id: str,
    resource_group_name: str,
    workspace_name: str,
    credential: Optional[object] = None,
) -> Tuple[ContainerClient, str]:
    try:
        # To write data to blob, the user needs the "Storage Blob Data Contributor" role on the storage account.
        if credential is None:
            from azure.identity import DefaultAzureCredential

            credential = DefaultAzureCredential()

        default_datastore = get_default_datastore(subscription_id, resource_group_name, workspace_name, credential)

        storage_endpoint = _get_storage_endpoint_from_metadata()
        account_url = STORAGE_ACCOUNT_URLS[DatastoreType.AZURE_BLOB].format(
            default_datastore.account_name, storage_endpoint
        )

        # Datastore is an AzureML notion, not a Blob Storage one, so we cannot get the
        # datastore name from the blob client. To build an azureml URI that contains the
        # datastore name, we generate the base URI here and pass it to the db client.
        container_client = ContainerClient(
            account_url=account_url, container_name=default_datastore.container_name, credential=credential
        )
        blob_base_uri = LONG_URI_FORMAT.format(
            subscription_id, resource_group_name, workspace_name, default_datastore.name, ""
        )
        if not blob_base_uri.endswith("/"):
            blob_base_uri += "/"

        logger.info(f"Got blob base URI: {blob_base_uri}")

        return container_client, blob_base_uri

    except Exception as e:
        stack_trace = traceback.format_exc()
        logger.error(f"Failed to get blob client: {e}, stack trace is {stack_trace}")
        raise


def get_default_datastore(
    subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
) -> Datastore:

    datastore_key = _get_datastore_client_key(subscription_id, resource_group_name, workspace_name)
    datastore = _get_datastore_from_cache(datastore_key=datastore_key)
    if datastore is None:
        with _thread_lock:
            datastore = _get_datastore_from_cache(datastore_key=datastore_key)
            if datastore is None:
                datastore = _get_default_datastore(subscription_id, resource_group_name, workspace_name, credential)
                _datastore_cache[datastore_key] = {
                    "expire_at": datetime.datetime.now() + datetime.timedelta(seconds=_cache_timeout),
                    "datastore": datastore,
                }
    return datastore


def _get_datastore_from_cache(datastore_key: str):
    datastore = _datastore_cache.get(datastore_key)

    if datastore and datastore["expire_at"] > datetime.datetime.now():
        return datastore["datastore"]

    return None


def _get_datastore_client_key(subscription_id: str, resource_group_name: str, workspace_name: str) -> str:
    # Azure names allow hyphens and underscores. Use "@" to avoid possible collisions.
    return f"{subscription_id}@{resource_group_name}@{workspace_name}"


def _get_default_datastore(
    subscription_id: str, resource_group_name: str, workspace_name: str, credential: Optional[object]
) -> Datastore:

    ml_client = MLClient(
        credential=credential,
        subscription_id=subscription_id,
        resource_group_name=resource_group_name,
        workspace_name=workspace_name,
    )

    default_datastore = ml_client.datastores.get_default()
    if default_datastore.type != DatastoreType.AZURE_BLOB:
        raise UserErrorException(
            message=f"Default datastore {default_datastore.name} is {default_datastore.type}, not AzureBlob."
        )

    return default_datastore
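
For reference, a minimal usage sketch of the new helper; the workspace coordinates are placeholders, and the caller needs the "Storage Blob Data Contributor" role for uploads to succeed:

```python
import logging

from promptflow.azure._storage.blob.client import get_datastore_container_client

logger = logging.getLogger(__name__)

# Placeholder workspace coordinates; credential=None falls back to
# DefaultAzureCredential inside the helper.
container_client, blob_base_uri = get_datastore_container_client(
    logger=logger,
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

# Blobs are addressed relative to the container; the matching azureml URI
# is blob_base_uri plus the same relative path.
container_client.get_blob_client("test/hello.txt").upload_blob(b"hello")
print(blob_base_uri + "test/hello.txt")
```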
@@ -101,4 +101,5 @@ def _init_container_client(endpoint: str, database_name: str, container_name: st


def _get_db_client_key(container_name: str, subscription_id: str, resource_group_name: str, workspace_name: str) -> str:
    return f"{subscription_id}_{resource_group_name}_{workspace_name}_{container_name}"
    # Azure names allow hyphens and underscores. Use "@" to avoid possible collisions.
    return f"{subscription_id}@{resource_group_name}@{workspace_name}@{container_name}"
33 changes: 28 additions & 5 deletions src/promptflow-azure/promptflow/azure/_storage/cosmosdb/span.py
@@ -2,14 +2,17 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
from typing import Any, Dict

from promptflow._constants import SpanFieldName
from azure.cosmos.container import ContainerProxy
from azure.storage.blob import ContainerClient

from promptflow._constants import SpanContextFieldName, SpanEventFieldName, SpanFieldName
from promptflow._sdk.entities._trace import Span as SpanEntity


class Span:

    name: str = None
    context: dict = None
    kind: str = None
@@ -25,6 +28,7 @@ class Span:
    partition_key: str = None
    collection_id: str = None
    created_by: dict = None
    external_event_data_uris: list = None

    def __init__(self, span: SpanEntity, collection_id: str, created_by: dict) -> None:
        self.name = span.name
@@ -42,21 +46,40 @@ def __init__(self, span: SpanEntity, collection_id: str, created_by: dict) -> No
        self.collection_id = collection_id
        self.id = span.span_id
        self.created_by = created_by
        self.external_event_data_uris = []

    def persist(self, client):
    def persist(self, cosmos_client: ContainerProxy, blob_container_client: ContainerClient, blob_base_uri: str):
        if self.id is None or self.partition_key is None or self.resource is None:
            return

        resource_attributes = self.resource.get(SpanFieldName.ATTRIBUTES, None)
        if resource_attributes is None:
            return

        if self.events and blob_container_client is not None and blob_base_uri is not None:
            self._persist_events(blob_container_client, blob_base_uri)

        from azure.cosmos.exceptions import CosmosResourceExistsError

        try:
            return client.create_item(body=self.to_dict())
            return cosmos_client.create_item(body=self.to_dict())
        except CosmosResourceExistsError:
            return None
            return

    def to_dict(self) -> Dict[str, Any]:
        return {k: v for k, v in self.__dict__.items() if v}

    def _persist_events(self, blob_container_client: ContainerClient, blob_base_uri: str):
        for idx, event in enumerate(self.events):
            event_data = json.dumps(event)
            blob_client = blob_container_client.get_blob_client(self._event_path(idx))
            blob_client.upload_blob(event_data)

            event[SpanEventFieldName.ATTRIBUTES] = {}
            self.external_event_data_uris.append(f"{blob_base_uri}{self._event_path(idx)}")

    EVENT_PATH_PREFIX = ".promptflow/.trace"

    def _event_path(self, idx: int) -> str:
        trace_id = self.context[SpanContextFieldName.TRACE_ID]
        return f"{self.EVENT_PATH_PREFIX}/{self.collection_id}/{trace_id}/{self.id}/{idx}"
28 changes: 23 additions & 5 deletions src/promptflow-devkit/promptflow/_sdk/_service/apis/collector.py
@@ -115,12 +115,13 @@ def _try_write_trace_to_cosmosdb(

    logger.info(f"Start writing trace to cosmosdb, total spans count: {len(all_spans)}.")
    start_time = datetime.now()

    from promptflow.azure._storage.cosmosdb.client import get_client
    from promptflow.azure._storage.cosmosdb.collection import CollectionCosmosDB
    from promptflow.azure._storage.cosmosdb.span import Span as SpanCosmosDB
    from promptflow.azure._storage.cosmosdb.summary import Summary

    # Load span and summary clients first time may slow.
    # Loading the span, collection and summary clients for the first time may be slow,
    # so we load the clients in parallel to warm them up.
    span_client_thread = ThreadWithContextVars(
        target=get_client,
@@ -134,18 +135,33 @@
    )
    collection_client_thread.start()

    line_summary_client_thread = ThreadWithContextVars(
        target=get_client,
        args=(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name, credential),
    )
    line_summary_client_thread.start()

    # Loading the created_by info for the first time may be slow, so we load it in parallel to warm it up.
    created_by_thread = ThreadWithContextVars(target=get_created_by_info_with_cache)
    created_by_thread.start()

    get_client(CosmosDBContainerName.LINE_SUMMARY, subscription_id, resource_group_name, workspace_name, credential)
    # Getting the default datastore may be slow, so we keep a cache for it.
    from promptflow.azure._storage.blob.client import get_datastore_container_client

    blob_container_client, blob_base_uri = get_datastore_container_client(
        logger=logger,
        subscription_id=subscription_id,
        resource_group_name=resource_group_name,
        workspace_name=workspace_name,
        credential=credential,
    )

    span_client_thread.join()
    created_by_thread.join()
    collection_client_thread.join()
    line_summary_client_thread.join()
    created_by_thread.join()

    created_by = get_created_by_info_with_cache()

    collection_client = get_client(
        CosmosDBContainerName.COLLECTION, subscription_id, resource_group_name, workspace_name, credential
    )
@@ -158,7 +174,9 @@
        span_client = get_client(
            CosmosDBContainerName.SPAN, subscription_id, resource_group_name, workspace_name, credential
        )
        result = SpanCosmosDB(span, collection_id, created_by).persist(span_client)
        result = SpanCosmosDB(span, collection_id, created_by).persist(
            span_client, blob_container_client, blob_base_uri
        )
        # None means the span already exists, so we don't need to persist the summary either.
        if result is not None:
            line_summary_client = get_client(
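
The warm-up in the collector reduces to this pattern; plain `threading.Thread` is shown for illustration, whereas the real code uses `ThreadWithContextVars` so context variables propagate to the workers:

```python
import threading


def warm_up(container_name: str) -> None:
    # Stand-in for get_client(...): the first call per container is slow,
    # later calls on the hot path hit the client cache.
    ...


threads = [
    threading.Thread(target=warm_up, args=(name,))
    for name in ("span", "collection", "line_summary")
]
for t in threads:
    t.start()
# ... do other slow first-time work here (blob client, created_by info) ...
for t in threads:
    t.join()
```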
@@ -0,0 +1,24 @@
import datetime

import pytest

from promptflow.azure._storage.blob.client import _datastore_cache, _get_datastore_client_key, _get_datastore_from_cache


@pytest.mark.unittest
class TestBlobClient:
    def test_get_datastore_from_cache(self):
        _datastore_cache["test"] = {
            "expire_at": datetime.datetime.now() + datetime.timedelta(0, -1),  # already expired
            "datastore": "test",
        }
        assert _get_datastore_from_cache("test") is None

        _datastore_cache["test"] = {
            "expire_at": datetime.datetime.now() + datetime.timedelta(1, 0),  # expires after 1 day
            "datastore": "test",
        }
        assert _get_datastore_from_cache("test") == "test"

    def test_get_datastore_client_key(self):
        assert _get_datastore_client_key("sub", "rg", "ws") == "sub@rg@ws"
@@ -2,7 +2,12 @@

import pytest

from promptflow.azure._storage.cosmosdb.client import _get_client_from_map, _get_container_lock, client_map
from promptflow.azure._storage.cosmosdb.client import (
    _get_client_from_map,
    _get_container_lock,
    _get_db_client_key,
    client_map,
)


@pytest.mark.unittest
@@ -25,3 +30,6 @@ def test_get_container_lock(self):
        assert container_lock is not None
        assert _get_container_lock("test2") != container_lock
        assert _get_container_lock("test") == container_lock

    def test_get_db_client_key(self):
        assert _get_db_client_key("container", "sub", "rg", "ws") == "sub@rg@ws@container"
51 changes: 41 additions & 10 deletions src/promptflow/tests/sdk_cli_azure_test/unittests/test_span.py
@@ -8,14 +8,16 @@
class TestSpan:
    FAKE_CREATED_BY = {"oid": "fake_oid"}
    FAKE_COLLECTION_ID = "fake_collection_id"
    FAKE_TRACE_ID = "0xacf2291a630af328da8fabd6bf49f653"
    FAKE_SPAN_ID = "0x9ded7ce65d5f7775"

    def test_to_dict(self):
        span = Span(
            SpanEntity(
                name="test",
                context={
                    "trace_id": "0xacf2291a630af328da8fabd6bf49f653",
                    "span_id": "0x9ded7ce65d5f7775",
                    "trace_id": self.FAKE_TRACE_ID,
                    "span_id": self.FAKE_SPAN_ID,
                },
                kind="test",
                parent_span_id="test",
@@ -39,10 +41,10 @@ def test_to_dict(self):
            "start_time": "test",
            "end_time": "test",
            "context": {
                "trace_id": "0xacf2291a630af328da8fabd6bf49f653",
                "span_id": "0x9ded7ce65d5f7775",
                "trace_id": self.FAKE_TRACE_ID,
                "span_id": self.FAKE_SPAN_ID,
            },
            "id": "0x9ded7ce65d5f7775",
            "id": self.FAKE_SPAN_ID,
            "collection_id": "fake_collection_id",
            "created_by": {"oid": "fake_oid"},
        }
@@ -51,8 +53,8 @@
            SpanEntity(
                name="test",
                context={
                    "trace_id": "0xacf2291a630af328da8fabd6bf49f653",
                    "span_id": "0x9ded7ce65d5f7775",
                    "trace_id": self.FAKE_TRACE_ID,
                    "span_id": self.FAKE_SPAN_ID,
                },
                kind="test",
                parent_span_id="test",
@@ -78,11 +80,40 @@
            "attributes": {"line_run_id": "test_line_run_id"},
            "partition_key": "test_session_id",
            "context": {
                "trace_id": "0xacf2291a630af328da8fabd6bf49f653",
                "span_id": "0x9ded7ce65d5f7775",
                "trace_id": self.FAKE_TRACE_ID,
                "span_id": self.FAKE_SPAN_ID,
            },
            "id": "0x9ded7ce65d5f7775",
            "id": self.FAKE_SPAN_ID,
            "partition_key": "test_session_id",
            "collection_id": "fake_collection_id",
            "created_by": {"oid": "fake_oid"},
        }

    def test_event_path(self):
        span = Span(
            SpanEntity(
                name="test",
                context={
                    "trace_id": self.FAKE_TRACE_ID,
                    "span_id": self.FAKE_SPAN_ID,
                },
                kind="test",
                parent_span_id="test",
                start_time="test",
                end_time="test",
                status={},
                attributes={},
                events=[],
                links=[],
                resource={},
                span_type=None,
                session_id=None,
            ),
            collection_id=self.FAKE_COLLECTION_ID,
            created_by=self.FAKE_CREATED_BY,
        )

        assert (
            span._event_path(1)
            == f".promptflow/.trace/{self.FAKE_COLLECTION_ID}/{self.FAKE_TRACE_ID}/{self.FAKE_SPAN_ID}/1"
        )