Skip to content

Commit

Permalink
py tracer fixes (langchain-ai#5377)
Browse files Browse the repository at this point in the history
  • Loading branch information
agola11 authored and Undertone0809 committed Jun 19, 2023
1 parent b08a832 commit 3599b2e
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 167 deletions.
12 changes: 6 additions & 6 deletions docs/tracing/agent_with_tracing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"id": "87027b0d-3a61-47cf-8a65-3002968be7f9",
"metadata": {
"tags": []
Expand All @@ -356,13 +356,13 @@
"source": [
"import os\n",
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchainpro-api-gateway-12bfv6cf.uc.gateway.dev\" # Uncomment this line if you want to use the hosted version\n",
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.langchain.plus\" # Uncomment this line if you want to use the hosted version\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = \"<YOUR-LANGCHAINPLUS-API-KEY>\" # Uncomment this line if you want to use the hosted version."
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 8,
"id": "5b4f49a2-7d09-4601-a8ba-976f0517c64c",
"metadata": {
"tags": []
Expand All @@ -379,7 +379,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 9,
"id": "029b4a57-dc49-49de-8f03-53c292144e09",
"metadata": {
"tags": []
Expand All @@ -397,7 +397,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 10,
"id": "91a85fb2-6027-4bd0-b1fe-2a3b3b79e2dd",
"metadata": {
"tags": []
Expand Down Expand Up @@ -426,7 +426,7 @@
"'1.0891804557407723'"
]
},
"execution_count": 15,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
184 changes: 160 additions & 24 deletions langchain/callbacks/tracers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,35 @@

import logging
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Optional
from uuid import UUID

import requests
from tenacity import retry, stop_after_attempt, wait_fixed
from requests.exceptions import HTTPError
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
Run,
RunCreate,
RunTypeEnum,
RunUpdate,
TracerSession,
TracerSessionCreate,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text

logger = logging.getLogger(__name__)


def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
Expand All @@ -34,7 +45,27 @@ def get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")


@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
class LangChainTracerAPIError(Exception):
    """The LangChain API answered a request with HTTP status 500.

    This is the exception type that ``retry_decorator`` is configured to
    retry on.
    """


class LangChainTracerUserError(Exception):
    """The LangChain API answered a request with an HTTP error other than 500.

    Raised from the ``HTTPError`` handlers when the response status code is
    not 500; ``retry_decorator`` does not retry on this type.
    """


class LangChainTracerError(Exception):
    """A request to the LangChain API failed with something other than ``HTTPError``.

    Raised from the generic ``except Exception`` handlers that wrap the
    request calls; ``retry_decorator`` does not retry on this type.
    """


# Shared retry policy applied to the functions that call the LangChain API.
# - retries only when LangChainTracerAPIError is raised,
# - stops after 3 attempts,
# - waits between attempts with exponential backoff bounded by min=4 and
#   max=10 (seconds, per tenacity's wait_exponential),
# - logs at WARNING level on `logger` before each sleep.
# NOTE(review): `reraise` is not set, so tenacity's default behaviour applies
# once the attempts are exhausted — confirm which exception callers should
# expect in that case.
retry_decorator = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(LangChainTracerAPIError),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)


@retry_decorator
def _get_tenant_id(
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
) -> str:
Expand All @@ -44,8 +75,24 @@ def _get_tenant_id(
return tenant_id_
endpoint_ = endpoint or get_endpoint()
headers_ = headers or get_headers()
response = requests.get(endpoint_ + "/tenants", headers=headers_)
raise_for_status_with_text(response)
response = None
try:
response = requests.get(endpoint_ + "/tenants", headers=headers_)
raise_for_status_with_text(response)
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to get tenant ID from LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to get tenant ID from LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to get tenant ID from LangChain API. {e}"
) from e

tenants: List[Dict[str, Any]] = response.json()
if not tenants:
raise ValueError(f"No tenants found for URL {endpoint_}")
Expand All @@ -72,6 +119,8 @@ def __init__(
self.example_id = example_id
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
self.session_extra = session_extra
# set max_workers to 1 to process tasks in order
self.executor = ThreadPoolExecutor(max_workers=1)

def on_chat_model_start(
self,
Expand Down Expand Up @@ -108,7 +157,7 @@ def ensure_tenant_id(self) -> str:
self.tenant_id = tenant_id
return tenant_id

@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
@retry_decorator
def ensure_session(self) -> TracerSession:
"""Upsert a session."""
if self.session is not None:
Expand All @@ -118,37 +167,124 @@ def ensure_session(self) -> TracerSession:
session_create = TracerSessionCreate(
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
)
r = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
raise_for_status_with_text(r)
self.session = TracerSession(**r.json())
response = None
try:
response = requests.post(
url,
data=session_create.json(),
headers=self._headers,
)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert session to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to upsert session to LangChain API. {e}"
)
except Exception as e:
raise LangChainTracerError(
f"Failed to upsert session to LangChain API. {e}"
) from e
self.session = TracerSession(**response.json())
return self.session

def _persist_run_nested(self, run: Run) -> None:
def _persist_run(self, run: Run) -> None:
"""Persist a run."""

@retry_decorator
def _persist_run_single(self, run: Run) -> None:
"""Persist a run."""
session = self.ensure_session()
child_runs = run.child_runs
if run.parent_run_id is None:
run.reference_example_id = self.example_id
run_dict = run.dict()
del run_dict["child_runs"]
run_create = RunCreate(**run_dict, session_id=session.id)
response = None
try:
response = requests.post(
f"{self._endpoint}/runs",
data=run_create.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
response.raise_for_status()
except HTTPError as e:
if response is not None and response.status_code == 500:
raise LangChainTracerAPIError(
f"Failed to upsert persist run to LangChain API. {e}"
)
else:
raise LangChainTracerUserError(
f"Failed to persist run to LangChain API. {e}"
)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
for child_run in child_runs:
child_run.parent_run_id = run.id
self._persist_run_nested(child_run)
raise LangChainTracerError(
f"Failed to persist run to LangChain API. {e}"
) from e

def _persist_run(self, run: Run) -> None:
"""Persist a run."""
run.reference_example_id = self.example_id
# TODO: Post first then patch
self._persist_run_nested(run)
@retry_decorator
def _update_run_single(self, run: Run) -> None:
    """Send the updated fields of a single run to the LangChain API.

    Builds a ``RunUpdate`` from ``run.dict()`` and issues a PATCH request to
    ``{self._endpoint}/runs/{run.id}`` using ``self._headers``.

    Args:
        run: The run whose current state should be sent.

    Raises:
        LangChainTracerAPIError: The API responded with HTTP 500. This is
            the error type ``retry_decorator`` retries on.
        LangChainTracerUserError: The API responded with any other HTTP
            error status.
        LangChainTracerError: The request failed for any other reason.
    """
    run_update = RunUpdate(**run.dict())
    # Stays None if requests.patch itself raises before returning.
    response = None
    try:
        response = requests.patch(
            f"{self._endpoint}/runs/{run.id}",
            data=run_update.json(),
            headers=self._headers,
        )
        response.raise_for_status()
    except HTTPError as e:
        # Chain the original HTTPError so its traceback is preserved.
        if response is not None and response.status_code == 500:
            raise LangChainTracerAPIError(
                f"Failed to update run to LangChain API. {e}"
            ) from e
        raise LangChainTracerUserError(
            f"Failed to update run to LangChain API. {e}"
        ) from e
    except Exception as e:
        raise LangChainTracerError(
            f"Failed to update run to LangChain API. {e}"
        ) from e

def _on_llm_start(self, run: Run) -> None:
    """Queue creation of an LLM run on the executor.

    A deep copy of ``run`` is handed to ``_persist_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)

def _on_chat_model_start(self, run: Run) -> None:
    """Queue creation of a chat model run on the executor.

    A deep copy of ``run`` is handed to ``_persist_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)

def _on_llm_end(self, run: Run) -> None:
    """Queue an update for a finished LLM run on the executor.

    A deep copy of ``run`` is handed to ``_update_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)

def _on_llm_error(self, run: Run) -> None:
    """Queue an update for an LLM run that errored.

    A deep copy of ``run`` is handed to ``_update_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)

def _on_chain_start(self, run: Run) -> None:
    """Queue creation of a chain run on the executor.

    A deep copy of ``run`` is handed to ``_persist_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)

def _on_chain_end(self, run: Run) -> None:
    """Queue an update for a finished chain run on the executor.

    A deep copy of ``run`` is handed to ``_update_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)

def _on_chain_error(self, run: Run) -> None:
    """Queue an update for a chain run that errored.

    A deep copy of ``run`` is handed to ``_update_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)

def _on_tool_start(self, run: Run) -> None:
    """Queue creation of a tool run on the executor.

    A deep copy of ``run`` is handed to ``_persist_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._persist_run_single, snapshot)

def _on_tool_end(self, run: Run) -> None:
    """Queue an update for a finished tool run on the executor.

    A deep copy of ``run`` is handed to ``_update_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)

def _on_tool_error(self, run: Run) -> None:
    """Queue an update for a tool run that errored.

    A deep copy of ``run`` is handed to ``_update_run_single``.
    """
    snapshot = run.copy(deep=True)
    self.executor.submit(self._update_run_single, snapshot)
13 changes: 12 additions & 1 deletion langchain/callbacks/tracers/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class ToolRun(BaseRun):
child_tool_runs: List[ToolRun] = Field(default_factory=list)


# Begin V2 API Schemas


class RunTypeEnum(str, Enum):
"""Enum for run types."""

Expand All @@ -105,7 +108,7 @@ class RunBase(BaseModel):
id: Optional[UUID]
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
extra: dict
extra: Optional[Dict[str, Any]] = None
error: Optional[str]
execution_order: int
child_execution_order: Optional[int]
Expand Down Expand Up @@ -144,5 +147,13 @@ def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
return values


class RunUpdate(BaseModel):
    """Fields of a run that can be sent in an update; all are declared Optional."""

    # The tracer builds this model from ``run.dict()`` and sends its JSON as
    # the body of the PATCH request to ``/runs/{run.id}``.
    end_time: Optional[datetime.datetime]
    error: Optional[str]
    outputs: Optional[dict]
    parent_run_id: Optional[UUID]
    reference_example_id: Optional[UUID]


ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
5 changes: 3 additions & 2 deletions tests/integration_tests/callbacks/test_langchain_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.callbacks import tracing_enabled
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI

questions = [
Expand Down Expand Up @@ -140,10 +141,10 @@ async def test_tracing_v2_environment_variable() -> None:


def test_tracing_v2_context_manager() -> None:
llm = OpenAI(temperature=0)
llm = ChatOpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
if "LANGCHAIN_TRACING_V2" in os.environ:
del os.environ["LANGCHAIN_TRACING_V2"]
Expand Down
Loading

0 comments on commit 3599b2e

Please sign in to comment.