From 60cd178275b5ce671ef35147802e237b823373a7 Mon Sep 17 00:00:00 2001
From: Willy Douhard
Date: Mon, 18 Nov 2024 11:52:03 -0800
Subject: [PATCH 1/2] feat: use async lc tracer instead of run_sync

---
 backend/chainlit/langchain/callbacks.py | 95 +++++++++++++------------
 1 file changed, 49 insertions(+), 46 deletions(-)

diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py
index 147f9de413..76960d1364 100644
--- a/backend/chainlit/langchain/callbacks.py
+++ b/backend/chainlit/langchain/callbacks.py
@@ -3,16 +3,18 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 from uuid import UUID
 
-from chainlit.context import context_var
-from chainlit.message import Message
-from chainlit.step import Step
-from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.tracers.base import AsyncBaseTracer
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
 from literalai.observability.step import TrueStepType
+from pydantic import BaseModel
+
+from chainlit.context import context_var
+from chainlit.message import Message
+from chainlit.step import Step
 
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
 
@@ -122,6 +124,8 @@ def ensure_values_serializable(self, data):
                 key: self.ensure_values_serializable(value)
                 for key, value in data.items()
             }
+        elif isinstance(data, BaseModel):
+            return data.model_dump()
         elif isinstance(data, list):
             return [self.ensure_values_serializable(item) for item in data]
         elif isinstance(data, (str, int, float, bool, type(None))):
@@ -249,7 +253,7 @@ def process_content(content: Any) -> Tuple[Dict, Optional[str]]:
 DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]
 
 
-class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
+class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
     steps: Dict[str, Step]
     parent_id_map: Dict[str, str]
     ignored_runs: set
@@ -268,7 +272,7 @@ def __init__(
         to_keep: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        BaseTracer.__init__(self, **kwargs)
+        AsyncBaseTracer.__init__(self, **kwargs)
         GenerationHelper.__init__(self)
         FinalStreamHelper.__init__(
             self,
@@ -296,7 +300,7 @@ def __init__(
         else:
             self.to_keep = to_keep
 
-    def on_chat_model_start(
+    async def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
@@ -305,8 +309,9 @@ def on_chat_model_start(
         parent_run_id: Optional["UUID"] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> Run:
         lc_messages = messages[0]
         self.chat_generations[str(run_id)] = {
             "input_messages": lc_messages,
@@ -315,46 +320,48 @@ def on_chat_model_start(
             "tt_first_token": None,
         }
 
-        return super().on_chat_model_start(
+        return await super().on_chat_model_start(
             serialized,
             messages,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
+            name=name,
             **kwargs,
         )
 
-    def on_llm_start(
+    async def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         *,
         run_id: "UUID",
+        parent_run_id: Optional[UUID] = None,
         tags: Optional[List[str]] = None,
-        parent_run_id: Optional["UUID"] = None,
         metadata: Optional[Dict[str, Any]] = None,
-        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Run:
-        self.completion_generations[str(run_id)] = {
-            "prompt": prompts[0],
-            "start": time.time(),
-            "token_count": 0,
-            "tt_first_token": None,
-        }
-        return super().on_llm_start(
+    ) -> None:
+        await super().on_llm_start(
             serialized,
             prompts,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
-            name=name,
             **kwargs,
         )
 
-    def on_llm_new_token(
+        self.completion_generations[str(run_id)] = {
+            "prompt": prompts[0],
+            "start": time.time(),
+            "token_count": 0,
+            "tt_first_token": None,
+        }
+
+        return None
+
+    async def on_llm_new_token(
         self,
         token: str,
         *,
@@ -362,7 +369,14 @@ def on_llm_new_token(
         run_id: "UUID",
         parent_run_id: Optional["UUID"] = None,
         **kwargs: Any,
-    ) -> Run:
+    ) -> None:
+        await super().on_llm_new_token(
+            token=token,
+            chunk=chunk,
+            run_id=run_id,
+            parent_run_id=parent_run_id,
+            **kwargs,
+        )
         if isinstance(chunk, ChatGenerationChunk):
             start = self.chat_generations[str(run_id)]
         else:
@@ -377,24 +391,13 @@ def on_llm_new_token(
         if self.answer_reached:
             if not self.final_stream:
                 self.final_stream = Message(content="")
-                self._run_sync(self.final_stream.send())
-            self._run_sync(self.final_stream.stream_token(token))
+                await self.final_stream.send()
+            await self.final_stream.stream_token(token)
             self.has_streamed_final_answer = True
         else:
             self.answer_reached = self._check_if_answer_reached()
 
-        return super().on_llm_new_token(
-            token,
-            chunk=chunk,
-            run_id=run_id,
-            parent_run_id=parent_run_id,
-        )
-
-    def _run_sync(self, co):  # TODO: WHAT TO DO WITH THIS?
-        context_var.set(self.context)
-        self.context.loop.create_task(co)
-
-    def _persist_run(self, run: Run) -> None:
+    async def _persist_run(self, run: Run) -> None:
         pass
 
     def _get_run_parent_id(self, run: Run):
@@ -445,8 +448,8 @@ def _should_ignore_run(self, run: Run):
             self.ignored_runs.add(str(run.id))
         return ignore, parent_id
 
-    def _start_trace(self, run: Run) -> None:
-        super()._start_trace(run)
+    async def _start_trace(self, run: Run) -> None:
+        await super()._start_trace(run)
         context_var.set(self.context)
 
         ignore, parent_id = self._should_ignore_run(run)
@@ -489,9 +492,9 @@ def _start_trace(self, run: Run) -> None:
 
         self.steps[str(run.id)] = step
 
-        self._run_sync(step.send())
+        await step.send()
 
-    def _on_run_update(self, run: Run) -> None:
+    async def _on_run_update(self, run: Run) -> None:
         """Process a run upon update."""
         context_var.set(self.context)
 
@@ -576,10 +579,10 @@ def _on_run_update(self, run: Run) -> None:
 
             if current_step:
                 current_step.end = utc_now()
-                self._run_sync(current_step.update())
+                await current_step.update()
 
             if self.final_stream and self.has_streamed_final_answer:
-                self._run_sync(self.final_stream.update())
+                await self.final_stream.update()
 
             return
 
@@ -599,16 +602,16 @@ def _on_run_update(self, run: Run) -> None:
             else output
         )
         current_step.end = utc_now()
-        self._run_sync(current_step.update())
+        await current_step.update()
 
-    def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+    async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
         context_var.set(self.context)
 
         if current_step := self.steps.get(str(run_id), None):
             current_step.is_error = True
             current_step.output = str(error)
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()
 
     on_llm_error = _on_error
     on_chain_error = _on_error
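[Note] The heart of patch 1 is swapping LangChain's sync `BaseTracer` for `AsyncBaseTracer`: every callback becomes a coroutine, so Chainlit UI calls (`step.send()`, `stream_token()`, `update()`) can be awaited in order, and the fire-and-forget `_run_sync` helper (which scheduled coroutines with `loop.create_task`) can be deleted along with its TODO. A minimal sketch of the same pattern, reduced to a toy tracer so it runs without Chainlit — `ConsoleTracer`, `_emit`, and the `demo()` payload are illustrative stand-ins, not Chainlit or LangChain API:

```python
import asyncio
from typing import Any, Optional
from uuid import UUID, uuid4

from langchain_core.tracers.base import AsyncBaseTracer
from langchain_core.tracers.schemas import Run


class ConsoleTracer(AsyncBaseTracer):
    """Toy async tracer: each hook is a coroutine, so side effects are
    awaited inline instead of scheduled fire-and-forget with create_task."""

    async def _persist_run(self, run: Run) -> None:
        # Nothing to persist in this sketch; Chainlit's tracer is a no-op here too.
        pass

    async def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # Let the base tracer record the token event on the run.
        await super().on_llm_new_token(
            token, run_id=run_id, parent_run_id=parent_run_id, **kwargs
        )
        # With the old sync BaseTracer this had to be wrapped as
        # self.context.loop.create_task(...); now it is simply awaited.
        await self._emit(token)

    async def _emit(self, token: str) -> None:
        # Stand-in for an awaited UI update such as step.stream_token(token).
        print(f"streamed token: {token!r}")


async def demo() -> None:
    tracer = ConsoleTracer()
    run_id = uuid4()
    # Start a run, then feed it tokens, exactly as a streaming LLM would.
    await tracer.on_llm_start({"name": "fake-llm"}, ["Say hi"], run_id=run_id)
    for token in ["Hello", ",", " world"]:
        await tracer.on_llm_new_token(token, run_id=run_id)


if __name__ == "__main__":
    asyncio.run(demo())
```

Because each token is awaited rather than queued, ordering is guaranteed by the event loop, which is what makes the `_run_sync` indirection removable.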
From 3cc6ace99e1783abefa99f01f20803c17399bc4f Mon Sep 17 00:00:00 2001
From: Mathijs de Bruin
Date: Tue, 19 Nov 2024 11:34:54 +0000
Subject: [PATCH 2/2] Add fallback for pydantic v1.

---
 backend/chainlit/langchain/callbacks.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/backend/chainlit/langchain/callbacks.py b/backend/chainlit/langchain/callbacks.py
index 76960d1364..7449663dd4 100644
--- a/backend/chainlit/langchain/callbacks.py
+++ b/backend/chainlit/langchain/callbacks.py
@@ -3,6 +3,7 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 from uuid import UUID
 
+import pydantic
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
@@ -10,7 +11,6 @@
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
 from literalai.observability.step import TrueStepType
-from pydantic import BaseModel
 
 from chainlit.context import context_var
 from chainlit.message import Message
@@ -124,8 +124,14 @@ def ensure_values_serializable(self, data):
                 key: self.ensure_values_serializable(value)
                 for key, value in data.items()
             }
-        elif isinstance(data, BaseModel):
-            return data.model_dump()
+        elif isinstance(data, pydantic.BaseModel):
+            # Fallback to support pydantic v1
+            # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
+            if pydantic.VERSION.startswith("1"):
+                return data.dict()
+
+            # pydantic v2
+            return data.model_dump()  # pyright: ignore reportAttributeAccessIssue
         elif isinstance(data, list):
             return [self.ensure_values_serializable(item) for item in data]
         elif isinstance(data, (str, int, float, bool, type(None))):
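[Note] Patch 2 guards the serialization added in patch 1: `.model_dump()` exists only on pydantic v2 models, while v1 models expose `.dict()`, so the callback dispatches on `pydantic.VERSION`. A standalone sketch of the same dispatch — `Point` and `dump_model` are made-up names for illustration:

```python
import pydantic


class Point(pydantic.BaseModel):
    x: int
    y: int


def dump_model(model: pydantic.BaseModel) -> dict:
    # Mirrors the callback's fallback: pydantic.VERSION is a string such as
    # "1.10.13" or "2.9.2", so a prefix check selects the right dump method.
    if pydantic.VERSION.startswith("1"):
        return model.dict()  # v1 API
    return model.model_dump()  # v2 API


print(dump_model(Point(x=1, y=2)))  # -> {'x': 1, 'y': 2}
```

A `hasattr(data, "model_dump")` probe would also work; checking the version string keeps the intent explicit and matches the migration guide linked in the patch's comment.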