Skip to content

Commit

Permalink
OPIK-186 [SDK] Allow users to configure the project name in LangChain…
Browse files Browse the repository at this point in the history
… integration (#374)

* OPIK-186 [SDK] Allow users to configure the project name in LangChain integration

* fix linter warning
  • Loading branch information
japdubengsub authored Oct 14, 2024
1 parent e220b8d commit 4330b09
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 8 deletions.
43 changes: 41 additions & 2 deletions sdks/python/src/opik/integrations/langchain/opik_tracer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Literal, Set
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Literal, Set, Union

from langchain_core.tracers import BaseTracer

Expand All @@ -14,6 +15,8 @@

from langchain_core.tracers.schemas import Run

LOGGER = logging.getLogger(__name__)


def _get_span_type(run: "Run") -> Literal["llm", "tool", "general"]:
if run.run_type in ["llm", "tool"]:
Expand All @@ -29,6 +32,7 @@ def __init__(
self,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
project_name: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Expand All @@ -37,6 +41,7 @@ def __init__(
Args:
tags: List of tags to be applied to each trace logged by the tracer.
metadata: Additional metadata for each trace logged by the tracer.
project_name: The name of the project to log data.
"""
super().__init__(**kwargs)
self._trace_default_metadata = metadata if metadata is not None else {}
Expand All @@ -52,7 +57,11 @@ def __init__(

self._externally_created_traces_ids: Set[str] = set()

self._opik_client = opik_client.get_client_cached()
self._project_name = project_name

self._opik_client = opik_client.Opik(
_use_batching=True, project_name=project_name
)

def _persist_run(self, run: "Run") -> None:
run_dict: Dict[str, Any] = run.dict()
Expand All @@ -74,13 +83,17 @@ def _process_start_trace(self, run: "Run") -> None:
self._track_root_run(run_dict)
else:
parent_span_data = self._span_data_map[run.parent_run_id]

project_name = self._get_project_name(parent_span_data)

span_data = span.SpanData(
trace_id=parent_span_data.trace_id,
parent_span_id=parent_span_data.id,
input=run_dict["inputs"],
metadata=run_dict["extra"],
name=run.name,
type=_get_span_type(run),
project_name=project_name,
)

self._span_data_map[run.id] = span_data
Expand All @@ -90,6 +103,24 @@ def _process_start_trace(self, run: "Run") -> None:
run.parent_run_id
]

def _get_project_name(
    self, parent_data: Union[trace.TraceData, span.SpanData]
) -> Optional[str]:
    """
    Pick the project name for a span that is nested under ``parent_data``.

    The parent's project name always takes precedence over the one this
    tracer was configured with. A warning is logged only when the tracer
    was given an explicit project name that differs from the parent's.

    Args:
        parent_data: Trace or span data the new span will be attached to.

    Returns:
        The project name to assign to the nested span (may be None).
    """
    configured_name = self._project_name
    names_differ = parent_data.project_name != configured_name

    # Tracer and parent agree -> nothing to reconcile.
    if not names_differ:
        return configured_name

    # Only warn if the user explicitly asked for a different project.
    if configured_name is not None:
        LOGGER.warning(
            "You are attempting to log data into a nested span under "
            f'the project name "{self._project_name}". '
            f'However, the project name "{parent_data.project_name}" '
            "from parent span will be used instead."
        )

    return parent_data.project_name

def _track_root_run(self, run_dict: Dict[str, Any]) -> None:
run_metadata = run_dict["extra"].get("metadata", {})
root_metadata = dict_utils.deepmerge(self._trace_default_metadata, run_metadata)
Expand Down Expand Up @@ -123,6 +154,7 @@ def _initialize_span_and_trace_from_scratch(
input=run_dict["inputs"],
metadata=root_metadata,
tags=self._trace_default_tags,
project_name=self._project_name,
)

self._created_traces_data_map[run_dict["id"]] = trace_data
Expand All @@ -134,6 +166,7 @@ def _initialize_span_and_trace_from_scratch(
input=run_dict["inputs"],
metadata=root_metadata,
tags=self._trace_default_tags,
project_name=self._project_name,
)

self._span_data_map[run_dict["id"]] = span_
Expand All @@ -144,13 +177,16 @@ def _attach_span_to_existing_span(
current_span_data: span.SpanData,
root_metadata: Dict[str, Any],
) -> None:
project_name = self._get_project_name(current_span_data)

span_data = span.SpanData(
trace_id=current_span_data.trace_id,
parent_span_id=current_span_data.id,
name=run_dict["name"],
input=run_dict["inputs"],
metadata=root_metadata,
tags=self._trace_default_tags,
project_name=project_name,
)
self._span_data_map[run_dict["id"]] = span_data
self._externally_created_traces_ids.add(span_data.trace_id)
Expand All @@ -161,13 +197,16 @@ def _attach_span_to_existing_trace(
current_trace_data: trace.TraceData,
root_metadata: Dict[str, Any],
) -> None:
project_name = self._get_project_name(current_trace_data)

span_data = span.SpanData(
trace_id=current_trace_data.id,
parent_span_id=None,
name=run_dict["name"],
input=run_dict["inputs"],
metadata=root_metadata,
tags=self._trace_default_tags,
project_name=project_name,
)
self._span_data_map[run_dict["id"]] = span_data
self._externally_created_traces_ids.add(current_trace_data.id)
Expand Down
41 changes: 35 additions & 6 deletions sdks/python/tests/library_integration/langchain/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import mock
import os

from opik.config import OPIK_PROJECT_DEFAULT_NAME
from opik.message_processing import streamer_constructors
from ...testlib import backend_emulator_message_processor
from ...testlib import (
Expand Down Expand Up @@ -28,8 +30,17 @@ def ensure_openai_configured():
raise Exception("OpenAI not configured!")


@pytest.mark.parametrize(
"project_name, expected_project_name",
[
(None, OPIK_PROJECT_DEFAULT_NAME),
("langchain-integration-test", "langchain-integration-test"),
],
)
def test_langchain__happyflow(
fake_streamer,
project_name,
expected_project_name,
):
fake_message_processor_: (
backend_emulator_message_processor.BackendEmulatorMessageProcessor
Expand Down Expand Up @@ -57,7 +68,9 @@ def test_langchain__happyflow(
synopsis_chain = prompt_template | llm
test_prompts = {"title": "Documentary about Bigfoot in Paris"}

callback = OpikTracer(tags=["tag1", "tag2"], metadata={"a": "b"})
callback = OpikTracer(
project_name=project_name, tags=["tag1", "tag2"], metadata={"a": "b"}
)
synopsis_chain.invoke(input=test_prompts, config={"callbacks": [callback]})

callback.flush()
Expand All @@ -74,6 +87,7 @@ def test_langchain__happyflow(
metadata={"a": "b"},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=expected_project_name,
spans=[
SpanModel(
id=ANY_BUT_NONE,
Expand All @@ -84,6 +98,7 @@ def test_langchain__happyflow(
metadata={"a": "b"},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=expected_project_name,
spans=[
SpanModel(
id=ANY_BUT_NONE,
Expand All @@ -94,6 +109,7 @@ def test_langchain__happyflow(
metadata={},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=expected_project_name,
spans=[],
),
SpanModel(
Expand All @@ -120,6 +136,7 @@ def test_langchain__happyflow(
},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=expected_project_name,
spans=[],
),
],
Expand Down Expand Up @@ -240,9 +257,16 @@ def test_langchain_callback__used_inside_another_track_function__data_attached_t
"construct_online_streamer",
mock_construct_online_streamer,
):
callback = OpikTracer(tags=["tag1", "tag2"], metadata={"a": "b"})
project_name = "langchain-integration-test"

callback = OpikTracer(
# we are trying to log span into another project, but parent's project name will be used
project_name="langchain-integration-test-nested-level",
tags=["tag1", "tag2"],
metadata={"a": "b"},
)

@opik.track(capture_output=True)
@opik.track(project_name=project_name, capture_output=True)
def f(x):
llm = fake.FakeListLLM(
responses=[
Expand All @@ -268,7 +292,7 @@ def f(x):
f("the-input")
opik.flush_tracker()

mock_construct_online_streamer.assert_called_once()
mock_construct_online_streamer.assert_called()

EXPECTED_TRACE_TREE = TraceModel(
id=ANY_BUT_NONE,
Expand All @@ -277,6 +301,7 @@ def f(x):
output={"output": "the-output"},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=project_name,
spans=[
SpanModel(
id=ANY_BUT_NONE,
Expand All @@ -285,6 +310,7 @@ def f(x):
output={"output": "the-output"},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=project_name,
spans=[
SpanModel(
id=ANY_BUT_NONE,
Expand All @@ -297,6 +323,7 @@ def f(x):
metadata={"a": "b"},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=project_name,
spans=[
SpanModel(
id=ANY_BUT_NONE,
Expand All @@ -309,6 +336,7 @@ def f(x):
metadata={},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=project_name,
spans=[],
),
SpanModel(
Expand All @@ -335,6 +363,7 @@ def f(x):
},
start_time=ANY_BUT_NONE,
end_time=ANY_BUT_NONE,
project_name=project_name,
spans=[],
),
],
Expand Down Expand Up @@ -407,7 +436,7 @@ def f():

opik.flush_tracker()

mock_construct_online_streamer.assert_called_once()
mock_construct_online_streamer.assert_called()

EXPECTED_TRACE_TREE = TraceModel(
id=ANY_BUT_NONE,
Expand Down Expand Up @@ -532,7 +561,7 @@ def f():
client.span(**span_data.__dict__)
opik.flush_tracker()

mock_construct_online_streamer.assert_called_once()
mock_construct_online_streamer.assert_called()

EXPECTED_SPANS_TREE = SpanModel(
id=ANY_BUT_NONE,
Expand Down

0 comments on commit 4330b09

Please sign in to comment.