Skip to content

Commit

Permalink
Bump LiteralAI to 0.0.625, refactor LiteralDataLayer (#1376)
Browse files Browse the repository at this point in the history
* Bump LiteralAI dependency, update related imports.
* Cleanup and organize imports from SQLAlchemy tests.
* Extensive unit test coverage for LiteralDataLayer.
* Consistent LiteralAI to Chainlit conversion, resolve PaginatedResponse exceptions.
* LiteralToChainlitConverter class for handling conversions, methods for converting steps, threads, and attachments.
* Allow manual setting of thread_id and id for Step and Element.

---------

Co-authored-by: EWouters <[email protected]>
  • Loading branch information
dokterbob and EWouters authored Oct 2, 2024
1 parent 11664b3 commit a0a8fa7
Show file tree
Hide file tree
Showing 11 changed files with 1,269 additions and 88 deletions.
227 changes: 168 additions & 59 deletions backend/chainlit/data/literalai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
import json
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, cast
from typing import Dict, List, Literal, Optional, Union, cast

import aiofiles
from httpx import HTTPStatusError, RequestError
from literalai import (
Attachment as LiteralAttachment,
Score as LiteralScore,
Step as LiteralStep,
Thread as LiteralThread,
)
from literalai.observability.filter import threads_filters as LiteralThreadsFilters
from literalai.observability.step import StepDict as LiteralStepDict

from chainlit.data.base import BaseDataLayer
from chainlit.data.utils import queue_until_user_message
from chainlit.element import Audio, Element, ElementDict, File, Image, Pdf, Text, Video
from chainlit.logger import logger
from chainlit.step import (
FeedbackDict,
Step,
StepDict,
StepType,
TrueStepType,
check_add_step_in_cot,
stub_step,
)
from chainlit.types import (
Feedback,
PageInfo,
Expand All @@ -14,50 +34,19 @@
ThreadFilter,
)
from chainlit.user import PersistedUser, User
from httpx import HTTPStatusError, RequestError
from literalai import Attachment
from literalai import Score as LiteralScore
from literalai import Step as LiteralStep
from literalai.filter import threads_filters as LiteralThreadsFilters
from literalai.step import StepDict as LiteralStepDict

if TYPE_CHECKING:
from chainlit.element import Element, ElementDict
from chainlit.step import FeedbackDict, StepDict


_data_layer: Optional[BaseDataLayer] = None


class LiteralDataLayer(BaseDataLayer):
def __init__(self, api_key: str, server: Optional[str]):
from literalai import AsyncLiteralClient
class LiteralToChainlitConverter:
@classmethod
def steptype_to_steptype(cls, step_type: Optional[StepType]) -> TrueStepType:
    """Reduce a ``StepType`` to a ``TrueStepType``.

    The three message types, as well as a missing (falsy) type, are all
    reported as ``"undefined"``; any other value is passed through unchanged.
    """
    message_types = ("user_message", "assistant_message", "system_message")
    if not step_type or step_type in message_types:
        return "undefined"
    return cast(TrueStepType, step_type)

self.client = AsyncLiteralClient(api_key=api_key, url=server)
logger.info("Chainlit data layer initialized")

def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict":
metadata = attachment.metadata or {}
return {
"chainlitKey": None,
"display": metadata.get("display", "side"),
"language": metadata.get("language"),
"autoPlay": metadata.get("autoPlay", None),
"playerConfig": metadata.get("playerConfig", None),
"page": metadata.get("page"),
"size": metadata.get("size"),
"type": metadata.get("type", "file"),
"forId": attachment.step_id,
"id": attachment.id or "",
"mime": attachment.mime,
"name": attachment.name or "",
"objectKey": attachment.object_key,
"url": attachment.url,
"threadId": attachment.thread_id,
}

def score_to_feedback_dict(
self, score: Optional[LiteralScore]
@classmethod
def score_to_feedbackdict(
cls,
score: Optional[LiteralScore],
) -> "Optional[FeedbackDict]":
if not score:
return None
Expand All @@ -68,7 +57,8 @@ def score_to_feedback_dict(
"comment": score.comment,
}

def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
@classmethod
def step_to_stepdict(cls, step: LiteralStep) -> "StepDict":
metadata = step.metadata or {}
input = (step.input or {}).get("content") or (
json.dumps(step.input) if step.input and step.input != {} else ""
Expand All @@ -95,7 +85,7 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
"id": step.id or "",
"threadId": step.thread_id or "",
"parentId": step.parent_id,
"feedback": self.score_to_feedback_dict(user_feedback),
"feedback": cls.score_to_feedbackdict(user_feedback),
"start": step.start_time,
"end": step.end_time,
"type": step.type or "undefined",
Expand All @@ -110,6 +100,116 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict":
"waitForAnswer": metadata.get("waitForAnswer", False),
}

@classmethod
def attachment_to_elementdict(cls, attachment: LiteralAttachment) -> ElementDict:
    """Build an ``ElementDict`` from an attachment.

    Presentation fields (display, language, autoPlay, playerConfig, page,
    size, type) are read from the attachment's metadata; identity and
    storage fields come from the attachment itself. ``chainlitKey`` is
    always ``None``.
    """
    meta = attachment.metadata or {}
    lookup = meta.get
    return {
        "chainlitKey": None,
        # Elements default to the side panel when no display mode is stored.
        "display": lookup("display", "side"),
        "language": lookup("language"),
        "autoPlay": lookup("autoPlay"),
        "playerConfig": lookup("playerConfig"),
        "page": lookup("page"),
        "size": lookup("size"),
        # Attachments without an explicit type are treated as plain files.
        "type": lookup("type", "file"),
        "forId": attachment.step_id,
        "id": attachment.id or "",
        "mime": attachment.mime,
        "name": attachment.name or "",
        "objectKey": attachment.object_key,
        "url": attachment.url,
        "threadId": attachment.thread_id,
    }

@classmethod
def attachment_to_element(
    cls, attachment: LiteralAttachment, thread_id: Optional[str] = None
) -> Element:
    """Build a Chainlit ``Element`` instance from an attachment.

    The concrete element class is chosen from the ``type`` entry of the
    attachment's metadata (defaulting to ``"file"``); unknown types fall
    back to the base ``Element`` class.

    Args:
        attachment: The attachment to convert.
        thread_id: Thread id to assign to the element. When omitted, the
            attachment's own ``thread_id`` is used.

    Returns:
        The element, with ``id``, ``for_id`` and ``object_key`` copied from
        the attachment.

    Raises:
        AssertionError: If neither ``thread_id`` nor ``attachment.thread_id``
            is set.
    """
    metadata = attachment.metadata or {}
    element_type = metadata.get("type", "file")

    element_class = {
        "file": File,
        "image": Image,
        "audio": Audio,
        "video": Video,
        "text": Text,
        "pdf": Pdf,
    }.get(element_type, Element)

    resolved_thread_id = thread_id or attachment.thread_id
    # Explicit check instead of a bare ``assert`` so that the validation
    # still runs under ``python -O``. AssertionError is kept so existing
    # callers observe the same exception type.
    if not resolved_thread_id:
        raise AssertionError(
            "A thread id is required: pass thread_id or use an attachment "
            "that has thread_id set."
        )

    element = element_class(
        name=attachment.name or "",
        display=metadata.get("display", "side"),
        language=metadata.get("language"),
        size=metadata.get("size"),
        url=attachment.url,
        mime=attachment.mime,
        thread_id=resolved_thread_id,
    )
    element.id = attachment.id or ""
    element.for_id = attachment.step_id
    element.object_key = attachment.object_key
    return element

@classmethod
def step_to_step(cls, step: LiteralStep) -> Step:
    """Build a Chainlit ``Step`` from a ``LiteralStep``.

    Name, type, ids and thread id are passed to the ``Step`` constructor;
    timing, input/output content, error flag, metadata, tags and generation
    are then copied onto the new step. Attachments, when present, are
    converted to elements using the new step's thread id.
    """
    converted = Step(
        name=step.name or "",
        type=cls.steptype_to_steptype(step.type),
        id=step.id,
        parent_id=step.parent_id,
        thread_id=step.thread_id or None,
    )

    converted.start = step.start_time
    converted.end = step.end_time
    converted.created_at = step.created_at

    # Only the "content" entry of the input/output payload is carried over.
    converted.input = "" if not step.input else step.input.get("content", "")
    converted.output = "" if not step.output else step.output.get("content", "")

    converted.is_error = bool(step.error)
    converted.metadata = step.metadata or {}
    converted.tags = step.tags
    converted.generation = step.generation

    if not step.attachments:
        return converted

    converted.elements = [
        cls.attachment_to_element(item, converted.thread_id)
        for item in step.attachments
    ]
    return converted

@classmethod
def thread_to_threaddict(cls, thread: LiteralThread) -> ThreadDict:
    """Build a ``ThreadDict`` from a ``LiteralThread``.

    Steps are converted with ``step_to_stepdict`` and every attachment of
    every step is converted with ``attachment_to_elementdict``. A thread
    without steps yields empty ``steps`` and ``elements`` lists, and a step
    without attachments contributes no elements.

    Args:
        thread: The thread to convert.

    Returns:
        The thread as a dictionary; ``createdAt`` falls back to ``""`` when
        the thread object has no ``created_at`` attribute.
    """
    literal_steps = thread.steps or []
    return {
        "id": thread.id,
        "createdAt": getattr(thread, "created_at", ""),
        "name": thread.name,
        "userId": thread.participant_id,
        "userIdentifier": thread.participant_identifier,
        "tags": thread.tags,
        "metadata": thread.metadata,
        "steps": [cls.step_to_stepdict(step) for step in literal_steps],
        "elements": [
            cls.attachment_to_elementdict(attachment)
            for step in literal_steps
            # Guard against steps whose attachments are unset, matching the
            # truthiness check done in step_to_step.
            for attachment in (step.attachments or [])
        ],
    }


class LiteralDataLayer(BaseDataLayer):
def __init__(self, api_key: str, server: Optional[str]):
    """Create the ``AsyncLiteralClient`` used by this data layer.

    Args:
        api_key: API key handed to the client.
        server: URL handed to the client as ``url``; may be ``None``.
    """
    # Kept as a function-scope import, as in the original implementation.
    from literalai import AsyncLiteralClient

    client = AsyncLiteralClient(url=server, api_key=api_key)
    self.client = client
    logger.info("Chainlit data layer initialized")

async def build_debug_url(self) -> str:
try:
project_id = await self.client.api.get_my_project_id()
Expand Down Expand Up @@ -239,7 +339,7 @@ async def get_element(
attachment = await self.client.api.get_attachment(id=element_id)
if not attachment:
return None
return self.attachment_to_element_dict(attachment)
return LiteralToChainlitConverter.attachment_to_elementdict(attachment)

@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
Expand Down Expand Up @@ -339,32 +439,41 @@ async def list_threads(
filters=literal_filters,
order_by={"column": "createdAt", "direction": "DESC"},
)

chainlit_threads = [
*map(LiteralToChainlitConverter.thread_to_threaddict, literal_response.data)
]

return PaginatedResponse(
pageInfo=PageInfo(
hasNextPage=literal_response.pageInfo.hasNextPage,
startCursor=literal_response.pageInfo.startCursor,
endCursor=literal_response.pageInfo.endCursor,
hasNextPage=literal_response.page_info.has_next_page,
startCursor=literal_response.page_info.start_cursor,
endCursor=literal_response.page_info.end_cursor,
),
data=literal_response.data,
data=chainlit_threads,
)

async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
from chainlit.step import check_add_step_in_cot, stub_step

async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
thread = await self.client.api.get_thread(id=thread_id)
if not thread:
return None
elements = [] # List[ElementDict]
steps = [] # List[StepDict]

elements: List[ElementDict] = []
steps: List[StepDict] = []
if thread.steps:
for step in thread.steps:
for attachment in step.attachments:
elements.append(self.attachment_to_element_dict(attachment))

if check_add_step_in_cot(step):
steps.append(self.step_to_step_dict(step))
elements.append(
LiteralToChainlitConverter.attachment_to_elementdict(attachment)
)

chainlit_step = LiteralToChainlitConverter.step_to_step(step)
if check_add_step_in_cot(chainlit_step):
steps.append(
LiteralToChainlitConverter.step_to_stepdict(step)
) # TODO: chainlit_step.to_dict()
else:
steps.append(stub_step(step))
steps.append(stub_step(chainlit_step))

return {
"createdAt": thread.created_at or "",
Expand Down
3 changes: 2 additions & 1 deletion backend/chainlit/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class ElementDict(TypedDict):

@dataclass
class Element:
# Thread id
thread_id: str = Field(default_factory=lambda: context.session.thread_id)
# The type of the element. This will be used to determine how to display the element in the UI.
type: ClassVar[ElementType]
# Name of the element, this will be used to reference the element in the UI.
Expand Down Expand Up @@ -88,7 +90,6 @@ def __post_init__(self) -> None:
trace_event(f"init {self.__class__.__name__}")
self.persisted = False
self.updatable = False
self.thread_id = context.session.thread_id

if not self.url and not self.path and not self.content:
raise ValueError("Must provide url, path or content to instantiate element")
Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from literalai import ChatGeneration, CompletionGeneration, GenerationMessage
from literalai.helper import utc_now
from literalai.step import TrueStepType
from literalai.observability.step import TrueStepType

DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]

Expand Down
2 changes: 1 addition & 1 deletion backend/chainlit/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
FileDict,
)
from literalai.helper import utc_now
from literalai.step import MessageStepType
from literalai.observability.step import MessageStepType


class MessageBase(ABC):
Expand Down
7 changes: 4 additions & 3 deletions backend/chainlit/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from chainlit.types import FeedbackDict
from literalai import BaseGeneration
from literalai.helper import utc_now
from literalai.step import StepType, TrueStepType
from literalai.observability.step import StepType, TrueStepType


def check_add_step_in_cot(step: "Step"):
Expand All @@ -30,7 +30,7 @@ def check_add_step_in_cot(step: "Step"):
return True


def stub_step(step: "Step"):
def stub_step(step: "Step") -> "StepDict":
return {
"type": step.type,
"name": step.name,
Expand Down Expand Up @@ -189,12 +189,13 @@ def __init__(
tags: Optional[List[str]] = None,
language: Optional[str] = None,
show_input: Union[bool, str] = "json",
thread_id: Optional[str] = None,
):
trace_event(f"init {self.__class__.__name__} {type}")
time.sleep(0.001)
self._input = ""
self._output = ""
self.thread_id = context.session.thread_id
self.thread_id = thread_id or context.session.thread_id
self.name = name or ""
self.type = type
self.id = id or str(uuid.uuid4())
Expand Down
Loading

0 comments on commit a0a8fa7

Please sign in to comment.