Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

py tracer fixes #5377

Merged
merged 8 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 160 additions & 27 deletions langchain/callbacks/tracers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,35 @@

import logging
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID

import requests
from tenacity import retry, stop_after_attempt, wait_fixed
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
Run,
RunCreate,
RunTypeEnum,
RunUpdate,
TracerSession,
TracerSessionCreate,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text

logger = logging.getLogger(__name__)


def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
Expand All @@ -34,7 +45,27 @@ def get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")


@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
class LangChainTracerAPIError(Exception):
"""An error occurred while communicating with the LangChain API."""


class LangChainTracerUserError(Exception):
"""An error occurred while communicating with the LangChain API."""


class LangChainTracerError(Exception):
"""An error occurred while communicating with the LangChain API."""


retry_decorator = retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(LangChainTracerAPIError),
before_sleep=before_sleep_log(logger, logging.WARNING),
)


@retry_decorator
def _get_tenant_id(
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
) -> str:
Expand All @@ -44,8 +75,24 @@ def _get_tenant_id(
return tenant_id_
endpoint_ = endpoint or get_endpoint()
headers_ = headers or get_headers()
response = requests.get(endpoint_ + "/tenants", headers=headers_)
raise_for_status_with_text(response)
response = None
try:
response = requests.get(endpoint_ + "/tenants", headers=headers_)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be removed, or is that another PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise_for_status_with_text(response)
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to get tenant ID from LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to get tenant ID from LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to get tenant ID from LangChain API. {e}"
) from e

tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint_}")
Expand All @@ -72,6 +119,8 @@ def __init__(
self.example_id = example_id
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
self.session_extra = session_extra
# set max_workers to 1 to process tasks in order
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does order matter? Can we rely on every user of the api being as strict about order as we are being here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'll fail if a child run gets posted before a parent run at least, right? (parent_run_id not found)

self.executor = ThreadPoolExecutor(max_workers=1)

def on_chat_model_start(
self,
Expand Down Expand Up @@ -108,7 +157,7 @@ def ensure_tenant_id(self) -> str:
self.tenant_id = tenant_id
return tenant_id

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
@retry_decorator
def ensure_session(self) -> TracerSession:
"""Upsert a session."""
if self.session is not None:
Expand All @@ -118,37 +167,121 @@ def ensure_session(self) -> TracerSession:
session_create = TracerSessionCreate(
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
)
r = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
raise_for_status_with_text(r)
self.session = TracerSession(**r.json())
response = None
try:
response = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert session to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to upsert session to LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to upsert session to LangChain API. {e}"
) from e
self.session = TracerSession(**response.json())
return self.session

def _persist_run_nested(self, run: Run) -> None:
def _persist_run(self, run: Run) -> None:
"""Persist a run."""

@retry_decorator
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
session = self.ensure_session()
child_runs = run.child_runs
run_dict = run.dict()
del run_dict["child_runs"]
run_create = RunCreate(**run_dict, session_id=session.id)
run.reference_example_id = self.example_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this, (I think) we should be only assigning if there is no parent_run_id, otherwise it links all the nested child runs to an example

run_create = RunCreate(**run.dict(), session_id=session.id)
response = None
try:
response = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert persist run to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to persist run to LangChain API. {e}"
)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
for child_run in child_runs:
child_run.parent_run_id = run.id
self._persist_run_nested(child_run)
raise LangChainTracerError(
f"Failed to persist run to LangChain API. {e}"
) from e

def _persist_run(self, run: Run) -> None:
"""Persist a run."""
run.reference_example_id = self.example_id
# TODO: Post first then patch
self._persist_run_nested(run)
@retry_decorator
def _update_run_single(self, run: Run) -> None:
"""Update a run."""
run_update = RunUpdate(**run.dict())
response = None
try:
response = requests.patch(
f"{self._endpoint}/runs/{run.id}",
data=run_update.json(),
headers=self._headers,
)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to update run to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(f"Failed to run to LangChain API. {e}")
except Exception as e:
raise LangChainTracerError(
f"Failed to update run to LangChain API. {e}"
) from e

def _on_llm_start(self, run: Run) -> None:
"""Persist an LLM run."""
self.executor.submit(self._persist_run_single, run)

def _on_chat_model_start(self, run: Run) -> None:
"""Persist an LLM run."""
self.executor.submit(self._persist_run_single, run)

def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""
self.executor.submit(self._update_run_single, run)

def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error."""
self.executor.submit(self._update_run_single, run)

def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start."""
self.executor.submit(self._persist_run_single, run)

def _on_chain_end(self, run: Run) -> None:
"""Process the Chain Run."""
self.executor.submit(self._update_run_single, run)

def _on_chain_error(self, run: Run) -> None:
"""Process the Chain Run upon error."""
self.executor.submit(self._update_run_single, run)

def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start."""
self.executor.submit(self._persist_run_single, run)

def _on_tool_end(self, run: Run) -> None:
"""Process the Tool Run."""
self.executor.submit(self._update_run_single, run)

def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error."""
self.executor.submit(self._update_run_single, run)
13 changes: 12 additions & 1 deletion langchain/callbacks/tracers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class ToolRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list)


# Begin V2 API Schemas


class RunTypeEnum(str, Enum):
"""Enum for run types."""

Expand All @@ -105,7 +108,7 @@ class RunBase(BaseModel):
id: Optional[UUID]
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict
extra: Optional[Dict[str, Any]] = None
error: Optional[str]
execution_order: int
child_execution_order: Optional[int]
Expand Down Expand Up @@ -144,5 +147,13 @@ def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values


class RunUpdate(BaseModel):
end_time: Optional[datetime.datetime]
error: Optional[str]
outputs: Optional[dict]
parent_run_id: Optional[UUID]
reference_example_id: Optional[UUID]


ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
5 changes: 3 additions & 2 deletions tests/integration_tests/callbacks/test_langchain_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI

questions = [
Expand Down Expand Up @@ -140,10 +141,10 @@ async def test_tracing_v2_environment_variable() -> None:


def test_tracing_v2_context_manager() -> None:
llm = OpenAI(temperature=0)
llm = ChatOpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING_V2" in os.environ:
del os.environ["LANGCHAIN_TRACING_V2"]
Expand Down