Skip to content

Commit

Permalink
support lazy load
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengfeiwang committed Mar 27, 2024
1 parent cbe2674 commit 54a71d9
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_sdk/_service/apis/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def trace_collector(
for scope_span in resource_span.scope_spans:
for span in scope_span.spans:
# TODO: persist with batch
span = Span._from_protobuf_object(span, resource=resource)
span = Span._from_protobuf_object(span, resource=resource, logger=logger)
if not cloud_trace_only:
span._persist()
all_spans.append(span)
Expand Down
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_sdk/_service/apis/line_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get(self):
client: PFClient = get_client_from_request()
args = ListLineRunParser.from_request()
line_runs: typing.List[LineRunEntity] = client._traces.list_line_runs(
session_id=args.session_id,
collection=args.session_id,
runs=args.runs,
experiments=args.experiments,
trace_ids=args.trace_ids,
Expand Down
20 changes: 17 additions & 3 deletions src/promptflow/promptflow/_sdk/entities/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import datetime
import json
import logging
import typing
import uuid
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -42,6 +43,17 @@
)


class Event:
    """Read-side helper for span events looked up through ``ORMEvent``."""

    @staticmethod
    def get(event_id: str) -> typing.Dict:
        """Fetch one event by id and return it as a dict with a decoded payload.

        :param event_id: identifier passed to ``ORMEvent.get``.
        :return: the event's ``data`` parsed from JSON, with
            ``attributes[SPAN_EVENTS_ATTRIBUTE_PAYLOAD]`` additionally parsed
            from its JSON string form.
        :raises KeyError: if the parsed event has no attributes or no payload entry.
        """
        event = json.loads(ORMEvent.get(event_id).data)
        attributes = event[SpanEventFieldName.ATTRIBUTES]
        # the payload is kept as a JSON string; decode it here to save effort in UX
        raw_payload = attributes[SPAN_EVENTS_ATTRIBUTE_PAYLOAD]
        attributes[SPAN_EVENTS_ATTRIBUTE_PAYLOAD] = json.loads(raw_payload)
        return event


class Span:
"""Span is exactly the same as OpenTelemetry Span."""

Expand Down Expand Up @@ -100,10 +112,11 @@ def _persist_events(self) -> None:

def _load_events(self) -> None:
    """Replace every entry of ``self.events`` with the event loaded via ``Event.get``.

    Each current entry is read for the event id found under
    ``attributes[SPAN_EVENTS_ATTRIBUTES_EVENT_ID]``; the loaded events are
    stored back in the same order.

    :raises KeyError: if an entry lacks the attributes or the event id key.
    """
    # One lookup per event through the shared `Event.get` helper (which also
    # decodes the payload). The list is built first and assigned once, so
    # `self.events` is left untouched if any lookup raises.
    self.events = [
        Event.get(event_id=event[SpanEventFieldName.ATTRIBUTES][SPAN_EVENTS_ATTRIBUTES_EVENT_ID])
        for event in self.events
    ]

def _persist_line_run(self) -> None:
# within a trace id, the line run will be created/updated in two cases:
Expand Down Expand Up @@ -188,11 +201,12 @@ def _from_protobuf_links(obj: typing.List[PBSpan.Link]) -> typing.List[typing.Di
return links

@staticmethod
def _from_protobuf_object(obj: PBSpan, resource: typing.Dict) -> "Span":
def _from_protobuf_object(obj: PBSpan, resource: typing.Dict, logger: logging.Logger) -> "Span":
# OpenTelemetry does not provide an official way to parse a Protocol Buffer Span object,
# so we parse it manually, relying on `MessageToJson`
# reference: https://github.com/open-telemetry/opentelemetry-python/issues/3700#issuecomment-2010704554
span_dict: dict = json.loads(MessageToJson(obj))
logger.debug("Received span: %s, resource: %s", json.dumps(span_dict), json.dumps(resource))
span_id = obj.span_id.hex()
trace_id = obj.trace_id.hex()
parent_id = obj.parent_span_id.hex()
Expand Down
11 changes: 2 additions & 9 deletions src/promptflow/promptflow/_sdk/operations/_trace_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,16 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
import typing

from promptflow._constants import SpanEventFieldName
from promptflow._sdk._constants import SPAN_EVENTS_ATTRIBUTE_PAYLOAD
from promptflow._sdk._orm.trace import Event as ORMEvent
from promptflow._sdk._orm.trace import LineRun as ORMLineRun
from promptflow._sdk._orm.trace import Span as ORMSpan
from promptflow._sdk.entities._trace import LineRun, Span
from promptflow._sdk.entities._trace import Event, LineRun, Span


class TraceOperations:
def get_event(self, event_id: str) -> typing.Dict:
    """Return the event identified by ``event_id`` as a dict.

    Delegates to ``Event.get`` so that event loading and payload decoding
    live in one place instead of being repeated here.

    :param event_id: identifier of the event to load.
    :return: whatever ``Event.get`` returns for that id.
    """
    return Event.get(event_id=event_id)

def get_span(
self,
Expand Down

0 comments on commit 54a71d9

Please sign in to comment.