diff --git a/src/promptflow/promptflow/_sdk/_service/apis/collector.py b/src/promptflow/promptflow/_sdk/_service/apis/collector.py index 97072da11ef..80cfc60014f 100644 --- a/src/promptflow/promptflow/_sdk/_service/apis/collector.py +++ b/src/promptflow/promptflow/_sdk/_service/apis/collector.py @@ -70,7 +70,7 @@ def trace_collector( for scope_span in resource_span.scope_spans: for span in scope_span.spans: # TODO: persist with batch - span = Span._from_protobuf_object(span, resource=resource) + span = Span._from_protobuf_object(span, resource=resource, logger=logger) if not cloud_trace_only: span._persist() all_spans.append(span) diff --git a/src/promptflow/promptflow/_sdk/_service/apis/line_run.py b/src/promptflow/promptflow/_sdk/_service/apis/line_run.py index 8675d858fda..422b0a9ba8f 100644 --- a/src/promptflow/promptflow/_sdk/_service/apis/line_run.py +++ b/src/promptflow/promptflow/_sdk/_service/apis/line_run.py @@ -86,7 +86,7 @@ def get(self): client: PFClient = get_client_from_request() args = ListLineRunParser.from_request() line_runs: typing.List[LineRunEntity] = client._traces.list_line_runs( - session_id=args.session_id, + collection=args.session_id, runs=args.runs, experiments=args.experiments, trace_ids=args.trace_ids, diff --git a/src/promptflow/promptflow/_sdk/entities/_trace.py b/src/promptflow/promptflow/_sdk/entities/_trace.py index 83c574f8ddb..d4bce4ac7d3 100644 --- a/src/promptflow/promptflow/_sdk/entities/_trace.py +++ b/src/promptflow/promptflow/_sdk/entities/_trace.py @@ -5,6 +5,7 @@ import copy import datetime import json +import logging import typing import uuid from dataclasses import asdict, dataclass @@ -42,6 +43,17 @@ ) +class Event: + @staticmethod + def get(event_id: str) -> typing.Dict: + orm_event = ORMEvent.get(event_id) + data = json.loads(orm_event.data) + # deserialize `events.attributes.payload` here to save effort in UX + payload = data[SpanEventFieldName.ATTRIBUTES][SPAN_EVENTS_ATTRIBUTE_PAYLOAD] + data[SpanEventFieldName.ATTRIBUTES][SPAN_EVENTS_ATTRIBUTE_PAYLOAD] = json.loads(payload) + return data + + class Span: """Span is exactly the same as OpenTelemetry Span.""" @@ -100,10 +112,11 @@ def _persist_events(self) -> None: def _load_events(self) -> None: # load events from table `events` and update `events.attributes` inplace + events = [] for i in range(len(self.events)): event_id = self.events[i][SpanEventFieldName.ATTRIBUTES][SPAN_EVENTS_ATTRIBUTES_EVENT_ID] - orm_event = ORMEvent.get(event_id) - self.events[i] = json.loads(orm_event.data) + events.append(Event.get(event_id=event_id)) + self.events = events def _persist_line_run(self) -> None: # within a trace id, the line run will be created/updated in two cases: @@ -188,11 +201,12 @@ def _from_protobuf_links(obj: typing.List[PBSpan.Link]) -> typing.List[typing.Di return links @staticmethod - def _from_protobuf_object(obj: PBSpan, resource: typing.Dict) -> "Span": + def _from_protobuf_object(obj: PBSpan, resource: typing.Dict, logger: logging.Logger) -> "Span": # Open Telemetry does not provide official way to parse Protocol Buffer Span object # so we need to parse it manually relying on `MessageToJson` # reference: https://github.com/open-telemetry/opentelemetry-python/issues/3700#issuecomment-2010704554 span_dict: dict = json.loads(MessageToJson(obj)) + logger.debug("Received span: %s, resource: %s", json.dumps(span_dict), json.dumps(resource)) span_id = obj.span_id.hex() trace_id = obj.trace_id.hex() parent_id = obj.parent_span_id.hex() diff --git a/src/promptflow/promptflow/_sdk/operations/_trace_operations.py b/src/promptflow/promptflow/_sdk/operations/_trace_operations.py index 6f2b3a98cb3..27422efec10 100644 --- a/src/promptflow/promptflow/_sdk/operations/_trace_operations.py +++ b/src/promptflow/promptflow/_sdk/operations/_trace_operations.py @@ -2,23 +2,16 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -import json import typing -from promptflow._constants import SpanEventFieldName -from promptflow._sdk._constants import SPAN_EVENTS_ATTRIBUTE_PAYLOAD -from promptflow._sdk._orm.trace import Event as ORMEvent from promptflow._sdk._orm.trace import LineRun as ORMLineRun from promptflow._sdk._orm.trace import Span as ORMSpan -from promptflow._sdk.entities._trace import LineRun, Span +from promptflow._sdk.entities._trace import Event, LineRun, Span class TraceOperations: def get_event(self, event_id: str) -> typing.Dict: - data = json.loads(ORMEvent.get(event_id=event_id).data) - payload = data[SpanEventFieldName.ATTRIBUTES][SPAN_EVENTS_ATTRIBUTE_PAYLOAD] - data[SpanEventFieldName.ATTRIBUTES][SPAN_EVENTS_ATTRIBUTE_PAYLOAD] = json.loads(payload) - return data + return Event.get(event_id=event_id) def get_span( self,