Skip to content

Commit

Permalink
[OPIK-187] [SDK] Allow users to configure the project name in track d…
Browse files Browse the repository at this point in the history
…ecorator (#348)

* add new "project name" attr to trace and span data-objects

* add support of "project name" arg to Opik's span() and trace()

* add support of "project name" arg to base decorator

* handle different project names in trace and spans properly (use project name of parent object)

* add support of "project name" arg to opik decorator

* add partial support of "project name" arg to openai decorator

* add e2e tests

* fix linter warnings

* show warning only when the users explicitly specifies a project name for the span that is different from  the parent span project name

* fix linter warnings

* fix the conditions under which the message is displayed
  • Loading branch information
japdubengsub authored Oct 8, 2024
1 parent 3749d11 commit ce4cf84
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 17 deletions.
33 changes: 25 additions & 8 deletions sdks/python/src/opik/api_objects/opik_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from ..message_processing import streamer_constructors, messages, jsonable_encoder
from ..rest_api import client as rest_api_client
from ..rest_api.types import dataset_public, trace_public, span_public
from ..rest_api.types import dataset_public, trace_public, span_public, project_public
from ..rest_api.core.api_error import ApiError
from .. import datetime_helpers, config, httpx_client

Expand Down Expand Up @@ -90,6 +90,7 @@ def trace(
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
feedback_scores: Optional[List[FeedbackScoreDict]] = None,
project_name: Optional[str] = None,
) -> trace.Trace:
"""
Create and log a new trace.
Expand All @@ -103,7 +104,8 @@ def trace(
output: The output data for the trace. This can be any valid JSON serializable object.
metadata: Additional metadata for the trace. This can be any valid JSON serializable object.
tags: Tags associated with the trace.
feedback_scores: The list of feedback score dicts assosiated with the trace. Dicts don't required to have an `id` value.
feedback_scores: The list of feedback score dicts associated with the trace. Dicts don't require to have an `id` value.
project_name: The name of the project.
Returns:
trace.Trace: The created trace object.
Expand All @@ -114,7 +116,7 @@ def trace(
)
create_trace_message = messages.CreateTraceMessage(
trace_id=id,
project_name=self._project_name,
project_name=project_name or self._project_name,
name=name,
start_time=start_time,
end_time=end_time,
Expand All @@ -134,7 +136,7 @@ def trace(
return trace.Trace(
id=id,
message_streamer=self._streamer,
project_name=self._project_name,
project_name=project_name or self._project_name,
)

def span(
Expand All @@ -152,6 +154,7 @@ def span(
tags: Optional[List[str]] = None,
usage: Optional[UsageDict] = None,
feedback_scores: Optional[List[FeedbackScoreDict]] = None,
project_name: Optional[str] = None,
) -> span.Span:
"""
Create and log a new span.
Expand All @@ -169,7 +172,8 @@ def span(
output: The output data for the span. This can be any valid JSON serializable object.
tags: Tags associated with the span.
usage: Usage data for the span.
feedback_scores: The list of feedback score dicts assosiated with the span. Dicts don't required to have an `id` value.
feedback_scores: The list of feedback score dicts associated with the span. Dicts don't require to have an `id` value.
project_name: The name of the project.
Returns:
span.Span: The created span object.
Expand All @@ -193,7 +197,7 @@ def span(
# This version is likely not final.
create_trace_message = messages.CreateTraceMessage(
trace_id=trace_id,
project_name=self._project_name,
project_name=project_name or self._project_name,
name=name,
start_time=start_time,
end_time=end_time,
Expand All @@ -207,7 +211,7 @@ def span(
create_span_message = messages.CreateSpanMessage(
span_id=id,
trace_id=trace_id,
project_name=self._project_name,
project_name=project_name or self._project_name,
parent_span_id=parent_span_id,
name=name,
type=type,
Expand All @@ -231,7 +235,7 @@ def span(
id=id,
parent_span_id=parent_span_id,
trace_id=trace_id,
project_name=self._project_name,
project_name=project_name or self._project_name,
message_streamer=self._streamer,
)

Expand Down Expand Up @@ -478,6 +482,19 @@ def get_span_content(self, id: str) -> span_public.SpanPublic:
"""
return self._rest_client.spans.get_span_by_id(id)

def get_project(self, id: str) -> project_public.ProjectPublic:
"""
Fetches a project by its unique identifier.
Parameters:
id (str): project if (uuid).
Returns:
project_public.ProjectPublic: pydantic model object with all the data associated with the project found.
Raises an error if project was not found
"""
return self._rest_client.projects.get_project_by_id(id)


@functools.lru_cache()
def get_client_cached() -> Opik:
Expand Down
1 change: 1 addition & 0 deletions sdks/python/src/opik/api_objects/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ class SpanData:
tags: Optional[List[str]] = None
usage: Optional[UsageDict] = None
feedback_scores: Optional[List[FeedbackScoreDict]] = None
project_name: Optional[str] = None

def update(self, **new_data: Any) -> "SpanData":
for key, value in new_data.items():
Expand Down
1 change: 1 addition & 0 deletions sdks/python/src/opik/api_objects/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class TraceData:
output: Optional[Dict[str, Any]] = None
tags: Optional[List[str]] = None
feedback_scores: Optional[List[FeedbackScoreDict]] = None
project_name: Optional[str] = None

def update(self, **new_data: Any) -> "TraceData":
for key, value in new_data.items():
Expand Down
1 change: 1 addition & 0 deletions sdks/python/src/opik/decorator/arguments_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ class StartSpanParameters(BaseArguments):
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
input: Optional[Dict[str, Any]] = None
project_name: Optional[str] = None
42 changes: 41 additions & 1 deletion sdks/python/src/opik/decorator/base_track_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def track(
capture_output: bool = True,
generations_aggregator: Optional[Callable[[List[Any]], Any]] = None,
flush: bool = False,
project_name: Optional[str] = None,
) -> Union[Callable, Callable[[Callable], Callable]]:
"""
Decorator to track the execution of a function.
Expand All @@ -57,6 +58,7 @@ def track(
capture_output: Whether to capture the output result.
generations_aggregator: Function to aggregate generation results.
flush: Whether to flush the client after logging.
project_name: The name of the project to log data.
Returns:
Callable: The decorated function(if used without parentheses)
Expand Down Expand Up @@ -84,6 +86,7 @@ def track(
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
project_name=project_name,
)

def decorator(func: Callable) -> Callable:
Expand All @@ -97,6 +100,7 @@ def decorator(func: Callable) -> Callable:
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
project_name=project_name,
)

return decorator
Expand All @@ -112,6 +116,7 @@ def _decorate(
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], Any]],
flush: bool,
project_name: Optional[str],
) -> Callable:
if not inspect_helpers.is_async(func):
return self._tracked_sync(
Expand All @@ -124,6 +129,7 @@ def _decorate(
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
project_name=project_name,
)

return self._tracked_async(
Expand All @@ -136,6 +142,7 @@ def _decorate(
capture_output=capture_output,
generations_aggregator=generations_aggregator,
flush=flush,
project_name=project_name,
)

def _tracked_sync(
Expand All @@ -149,6 +156,7 @@ def _tracked_sync(
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], str]],
flush: bool,
project_name: Optional[str],
) -> Callable:
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any: # type: ignore
Expand All @@ -159,6 +167,7 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore
tags=tags,
metadata=metadata,
capture_input=capture_input,
project_name=project_name,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -205,6 +214,7 @@ def _tracked_async(
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], str]],
flush: bool,
project_name: Optional[str],
) -> Callable:
@functools.wraps(func)
async def wrapper(*args, **kwargs) -> Any: # type: ignore
Expand All @@ -215,6 +225,7 @@ async def wrapper(*args, **kwargs) -> Any: # type: ignore
tags=tags,
metadata=metadata,
capture_input=capture_input,
project_name=project_name,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -257,6 +268,7 @@ def _before_call(
tags: Optional[List[str]],
metadata: Optional[Dict[str, Any]],
capture_input: bool,
project_name: Optional[str],
args: Tuple,
kwargs: Dict[str, Any],
) -> None:
Expand All @@ -272,6 +284,7 @@ def _before_call(
tags=tags,
metadata=metadata,
capture_input=capture_input,
project_name=project_name,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -307,6 +320,17 @@ def _create_span(
if current_span_data is not None:
# There is already at least one span in current context.
# Simply attach a new span to it.

if start_span_arguments.project_name != current_span_data.project_name:
if start_span_arguments.project_name is not None:
LOGGER.warning(
"You are attempting to log data into a nested span under "
f'the project name "{start_span_arguments.project_name}". '
f'However, the project name "{current_span_data.project_name}" '
"from parent span will be used instead."
)
start_span_arguments.project_name = current_span_data.project_name

span_data = span.SpanData(
id=helpers.generate_id(),
parent_span_id=current_span_data.id,
Expand All @@ -317,16 +341,27 @@ def _create_span(
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
)
context_storage.add_span_data(span_data)
return

if current_trace_data is not None and current_span_data is None:
# By default we expect trace to be created with a span.
# By default, we expect trace to be created with a span.
# But there can be cases when trace was created and added
# to context manually (not via decorator).
# In that case decorator should just create a span for the existing trace.

if start_span_arguments.project_name != current_trace_data.project_name:
if start_span_arguments.project_name is not None:
LOGGER.warning(
"You are attempting to log data into a nested span under "
f'the project name "{start_span_arguments.project_name}". '
f'However, the project name "{current_trace_data.project_name}" '
"from the trace will be used instead."
)
start_span_arguments.project_name = current_trace_data.project_name

span_data = span.SpanData(
id=helpers.generate_id(),
parent_span_id=None,
Expand All @@ -337,6 +372,7 @@ def _create_span(
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
)
context_storage.add_span_data(span_data)
return
Expand All @@ -351,6 +387,7 @@ def _create_span(
input=start_span_arguments.input,
metadata=start_span_arguments.metadata,
tags=start_span_arguments.tags,
project_name=start_span_arguments.project_name,
)
TRACES_CREATED_BY_DECORATOR.add(trace_data.id)

Expand All @@ -364,6 +401,7 @@ def _create_span(
tags=start_span_arguments.tags,
metadata=start_span_arguments.metadata,
input=start_span_arguments.input,
project_name=start_span_arguments.project_name,
)

context_storage.set_trace_data(trace_data)
Expand All @@ -384,6 +422,7 @@ def _create_distributed_node_root_span(
metadata=start_span_arguments.metadata,
tags=start_span_arguments.tags,
type=start_span_arguments.type,
project_name=start_span_arguments.project_name,
)
context_storage.add_span_data(span_data)

Expand Down Expand Up @@ -481,6 +520,7 @@ def _start_span_inputs_preprocessor(
capture_input: bool,
args: Tuple,
kwargs: Dict[str, Any],
project_name: Optional[str],
) -> arguments_helpers.StartSpanParameters: ...

@abc.abstractmethod
Expand Down
8 changes: 7 additions & 1 deletion sdks/python/src/opik/decorator/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _start_span_inputs_preprocessor(
capture_input: bool,
args: Tuple,
kwargs: Dict[str, Any],
project_name: Optional[str],
) -> arguments_helpers.StartSpanParameters:
input = (
inspect_helpers.extract_inputs(func, args, kwargs)
Expand All @@ -36,7 +37,12 @@ def _start_span_inputs_preprocessor(
name = name if name is not None else func.__name__

result = arguments_helpers.StartSpanParameters(
name=name, input=input, type=type, tags=tags, metadata=metadata
name=name,
input=input,
type=type,
tags=tags,
metadata=metadata,
project_name=project_name,
)

return result
Expand Down
8 changes: 7 additions & 1 deletion sdks/python/src/opik/integrations/openai/openai_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _start_span_inputs_preprocessor(
capture_input: bool,
args: Optional[Tuple],
kwargs: Optional[Dict[str, Any]],
project_name: Optional[str],
) -> arguments_helpers.StartSpanParameters:
assert (
kwargs is not None
Expand All @@ -55,7 +56,12 @@ def _start_span_inputs_preprocessor(
tags = ["openai"]

result = arguments_helpers.StartSpanParameters(
name=name, input=input, type=type, tags=tags, metadata=metadata
name=name,
input=input,
type=type,
tags=tags,
metadata=metadata,
project_name=project_name,
)

return result
Expand Down
5 changes: 4 additions & 1 deletion sdks/python/tests/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import os
import random
import string
from typing import Final

import opik
import opik.api_objects.opik_client

import pytest

OPIK_E2E_TESTS_PROJECT_NAME: Final[str] = "e2e-tests"


def _random_chars(n: int = 6) -> str:
return "".join(random.choice(string.ascii_letters) for _ in range(n))


@pytest.fixture(scope="session")
def configure_e2e_tests_env():
os.environ["OPIK_PROJECT_NAME"] = "e2e-tests"
os.environ["OPIK_PROJECT_NAME"] = OPIK_E2E_TESTS_PROJECT_NAME


@pytest.fixture()
Expand Down
Loading

0 comments on commit ce4cf84

Please sign in to comment.