feat: use async lc tracer instead of run_sync (#1529)
Co-authored-by: Mathijs de Bruin <[email protected]>
willydouhard and dokterbob authored Nov 19, 2024
1 parent 2bd47f5 commit ff26451
Showing 1 changed file with 55 additions and 46 deletions.
backend/chainlit/langchain/callbacks.py (101 changes: 55 additions & 46 deletions)
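
In the previous implementation, Chainlit coroutines such as step.send() and final_stream.stream_token() were fired off from synchronous BaseTracer callbacks through a _run_sync helper (context_var.set plus loop.create_task, visible in the deletions below), so they ran unawaited in the background. Subclassing langchain_core's AsyncBaseTracer makes the callbacks themselves coroutines, so the same calls can simply be awaited and complete in order. A condensed before/after sketch, abridged from the diff below:

    # Before: sync callback; the coroutine is scheduled on the loop, never awaited
    def _start_trace(self, run):
        super()._start_trace(run)
        step = self.steps[str(run.id)]
        self._run_sync(step.send())  # i.e. self.context.loop.create_task(...)

    # After: async callback; the step is fully sent before the tracer moves on
    async def _start_trace(self, run):
        await super()._start_trace(run)
        step = self.steps[str(run.id)]
        await step.send()
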
@@ -3,17 +3,19 @@
 from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 from uuid import UUID

-from chainlit.context import context_var
-from chainlit.message import Message
-from chainlit.step import Step
-from langchain.callbacks.tracers.base import BaseTracer
+import pydantic
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from langchain_core.tracers.base import AsyncBaseTracer
 from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
 from literalai.helper import utc_now
 from literalai.observability.step import TrueStepType

+from chainlit.context import context_var
+from chainlit.message import Message
+from chainlit.step import Step
+
 DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]


@@ -122,6 +124,14 @@ def ensure_values_serializable(self, data):
                 key: self.ensure_values_serializable(value)
                 for key, value in data.items()
             }
+        elif isinstance(data, pydantic.BaseModel):
+            # Fallback to support pydantic v1
+            # https://docs.pydantic.dev/latest/migration/#changes-to-pydanticbasemodel
+            if pydantic.VERSION.startswith("1"):
+                return data.dict()
+
+            # pydantic v2
+            return data.model_dump()  # pyright: ignore reportAttributeAccessIssue
         elif isinstance(data, list):
             return [self.ensure_values_serializable(item) for item in data]
         elif isinstance(data, (str, int, float, bool, type(None))):
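
ensure_values_serializable now flattens pydantic models too, branching on the installed major version: v1 models expose .dict(), which pydantic v2 renamed to .model_dump() (see the linked migration guide). The same guard in isolation, with a hypothetical Profile model:

    import pydantic

    class Profile(pydantic.BaseModel):
        name: str
        age: int

    def to_dict(model: pydantic.BaseModel) -> dict:
        # pydantic v1 exposes .dict(); v2 renamed it to .model_dump()
        if pydantic.VERSION.startswith("1"):
            return model.dict()
        return model.model_dump()

    print(to_dict(Profile(name="ada", age=36)))  # {'name': 'ada', 'age': 36}
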
Expand Down Expand Up @@ -249,7 +259,7 @@ def process_content(content: Any) -> Tuple[Dict, Optional[str]]:
DEFAULT_TO_KEEP = ["retriever", "llm", "agent", "chain", "tool"]


class LangchainTracer(BaseTracer, GenerationHelper, FinalStreamHelper):
class LangchainTracer(AsyncBaseTracer, GenerationHelper, FinalStreamHelper):
steps: Dict[str, Step]
parent_id_map: Dict[str, str]
ignored_runs: set
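
The base-class swap is the heart of the commit: AsyncBaseTracer (from langchain_core.tracers.base) declares its run-lifecycle hooks as coroutines, which is why every override below gains async/await. A minimal standalone subclass, sketched from the hooks this diff overrides and assuming, as the diff's no-op override suggests, that _persist_run is the one required method:

    from langchain_core.tracers.base import AsyncBaseTracer
    from langchain_core.tracers.schemas import Run

    class LoggingTracer(AsyncBaseTracer):
        async def _persist_run(self, run: Run) -> None:
            # Nothing to persist; Chainlit's tracer also makes this a no-op.
            pass

        async def _on_run_update(self, run: Run) -> None:
            # Called when a run finishes; run.run_type is e.g. "llm" or "chain".
            print(f"run {run.id} ({run.run_type}) updated")
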
@@ -268,7 +278,7 @@ def __init__(
         to_keep: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> None:
-        BaseTracer.__init__(self, **kwargs)
+        AsyncBaseTracer.__init__(self, **kwargs)
         GenerationHelper.__init__(self)
         FinalStreamHelper.__init__(
             self,
@@ -296,7 +306,7 @@ def __init__(
         else:
             self.to_keep = to_keep

-    def on_chat_model_start(
+    async def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
@@ -305,8 +315,9 @@ def on_chat_model_start(
         parent_run_id: Optional["UUID"] = None,
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Any:
+    ) -> Run:
         lc_messages = messages[0]
         self.chat_generations[str(run_id)] = {
             "input_messages": lc_messages,
@@ -315,54 +326,63 @@
             "tt_first_token": None,
         }

-        return super().on_chat_model_start(
+        return await super().on_chat_model_start(
             serialized,
             messages,
             run_id=run_id,
             parent_run_id=parent_run_id,
             tags=tags,
             metadata=metadata,
+            name=name,
             **kwargs,
         )

-    def on_llm_start(
+    async def on_llm_start(
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         *,
         run_id: "UUID",
-        parent_run_id: Optional[UUID] = None,
-        tags: Optional[List[str]] = None,
+        parent_run_id: Optional["UUID"] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        name: Optional[str] = None,
         **kwargs: Any,
-    ) -> Run:
-        self.completion_generations[str(run_id)] = {
-            "prompt": prompts[0],
-            "start": time.time(),
-            "token_count": 0,
-            "tt_first_token": None,
-        }
-        return super().on_llm_start(
+    ) -> None:
+        await super().on_llm_start(
             serialized,
             prompts,
             run_id=run_id,
             parent_run_id=parent_run_id,
-            tags=tags,
             metadata=metadata,
+            name=name,
             **kwargs,
         )

-    def on_llm_new_token(
+        self.completion_generations[str(run_id)] = {
+            "prompt": prompts[0],
+            "start": time.time(),
+            "token_count": 0,
+            "tt_first_token": None,
+        }
+
+        return None
+
+    async def on_llm_new_token(
         self,
         token: str,
         *,
         chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
         run_id: "UUID",
         parent_run_id: Optional["UUID"] = None,
         **kwargs: Any,
-    ) -> Run:
+    ) -> None:
+        await super().on_llm_new_token(
+            token=token,
+            chunk=chunk,
+            run_id=run_id,
+            parent_run_id=parent_run_id,
+            **kwargs,
+        )
         if isinstance(chunk, ChatGenerationChunk):
             start = self.chat_generations[str(run_id)]
         else:
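
The per-run dicts above ("start", "token_count", "tt_first_token") feed the generation metadata that Chainlit reports. The token-by-token updates happen in lines elided between these hunks; a self-contained sketch of the bookkeeping such fields typically support (illustrative, not the exact elided code):

    import time

    def record_token(gen: dict) -> None:
        # First token: capture time-to-first-token in ms, then count tokens.
        if gen["tt_first_token"] is None:
            gen["tt_first_token"] = (time.time() - gen["start"]) * 1000
        gen["token_count"] += 1

    gen = {"start": time.time(), "token_count": 0, "tt_first_token": None}
    record_token(gen)  # gen["tt_first_token"] now holds the first-token latency
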
@@ -377,24 +397,13 @@ def on_llm_new_token(
         if self.answer_reached:
             if not self.final_stream:
                 self.final_stream = Message(content="")
-                self._run_sync(self.final_stream.send())
-            self._run_sync(self.final_stream.stream_token(token))
+                await self.final_stream.send()
+            await self.final_stream.stream_token(token)
             self.has_streamed_final_answer = True
         else:
             self.answer_reached = self._check_if_answer_reached()

-        return super().on_llm_new_token(
-            token,
-            chunk=chunk,
-            run_id=run_id,
-            parent_run_id=parent_run_id,
-        )
-
-    def _run_sync(self, co):  # TODO: WHAT TO DO WITH THIS?
-        context_var.set(self.context)
-        self.context.loop.create_task(co)
-
-    def _persist_run(self, run: Run) -> None:
+    async def _persist_run(self, run: Run) -> None:
         pass

     def _get_run_parent_id(self, run: Run):
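
_check_if_answer_reached comes from FinalStreamHelper, which is outside this diff; with DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"], it decides when to start streaming the final answer. A simplified illustration of that prefix check, assuming the helper keeps a window of recent tokens (hypothetical, for orientation only):

    DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]

    def check_if_answer_reached(last_tokens: list[str]) -> bool:
        # Compare the trailing tokens against the configured answer prefix.
        window = [t.strip() for t in last_tokens[-len(DEFAULT_ANSWER_PREFIX_TOKENS):]]
        return window == DEFAULT_ANSWER_PREFIX_TOKENS

    assert check_if_answer_reached(["I", "think", "Final", "Answer", ":"])
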
@@ -445,8 +454,8 @@ def _should_ignore_run(self, run: Run):
             self.ignored_runs.add(str(run.id))
         return ignore, parent_id

-    def _start_trace(self, run: Run) -> None:
-        super()._start_trace(run)
+    async def _start_trace(self, run: Run) -> None:
+        await super()._start_trace(run)
         context_var.set(self.context)

         ignore, parent_id = self._should_ignore_run(run)
@@ -489,9 +498,9 @@ def _start_trace(self, run: Run) -> None:

         self.steps[str(run.id)] = step

-        self._run_sync(step.send())
+        await step.send()

-    def _on_run_update(self, run: Run) -> None:
+    async def _on_run_update(self, run: Run) -> None:
         """Process a run upon update."""
         context_var.set(self.context)
@@ -576,10 +585,10 @@ def _on_run_update(self, run: Run) -> None:

             if current_step:
                 current_step.end = utc_now()
-                self._run_sync(current_step.update())
+                await current_step.update()

             if self.final_stream and self.has_streamed_final_answer:
-                self._run_sync(self.final_stream.update())
+                await self.final_stream.update()

             return
@@ -599,16 +608,16 @@ def _on_run_update(self, run: Run) -> None:
                 else output
             )
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()

-    def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
+    async def _on_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any):
         context_var.set(self.context)

         if current_step := self.steps.get(str(run_id), None):
             current_step.is_error = True
             current_step.output = str(error)
             current_step.end = utc_now()
-            self._run_sync(current_step.update())
+            await current_step.update()

     on_llm_error = _on_error
     on_chain_error = _on_error
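
With every callback now a coroutine, the tracer must be driven from async code: LangChain's ainvoke/astream entry points rather than the sync invoke/stream. A hedged usage sketch inside a Chainlit message handler (the import path follows this file's location; ChatOpenAI is a placeholder model, and Chainlit's public alias cl.LangchainCallbackHandler may be the more common way to reach this class):

    import chainlit as cl
    from chainlit.langchain.callbacks import LangchainTracer
    from langchain_openai import ChatOpenAI  # placeholder model

    @cl.on_message
    async def main(message: cl.Message):
        model = ChatOpenAI()
        # An async tracer's callbacks are awaited by LangChain's async path,
        # so use ainvoke/astream rather than invoke/stream.
        await model.ainvoke(
            message.content,
            config={"callbacks": [LangchainTracer()]},
        )
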