diff --git a/backend/chainlit/data/literalai.py b/backend/chainlit/data/literalai.py index 694198e1d2..92875b045c 100644 --- a/backend/chainlit/data/literalai.py +++ b/backend/chainlit/data/literalai.py @@ -5,6 +5,7 @@ from chainlit.data.base import BaseDataLayer from chainlit.data.utils import queue_until_user_message from chainlit.logger import logger +from chainlit.step import Step, TrueStepType, StepType from chainlit.types import ( Feedback, PageInfo, @@ -14,50 +15,28 @@ ThreadFilter, ) from chainlit.user import PersistedUser, User +from chainlit.element import Element, ElementDict +from chainlit.step import FeedbackDict, StepDict + from httpx import HTTPStatusError, RequestError -from literalai import Attachment +from literalai import Attachment, Thread as LiteralThread from literalai import Score as LiteralScore from literalai import Step as LiteralStep from literalai.observability.filter import threads_filters as LiteralThreadsFilters from literalai.observability.step import StepDict as LiteralStepDict -if TYPE_CHECKING: - from chainlit.element import Element, ElementDict - from chainlit.step import FeedbackDict, StepDict - - -_data_layer: Optional[BaseDataLayer] = None - - -class LiteralDataLayer(BaseDataLayer): - def __init__(self, api_key: str, server: Optional[str]): - from literalai import AsyncLiteralClient - self.client = AsyncLiteralClient(api_key=api_key, url=server) - logger.info("Chainlit data layer initialized") - - def attachment_to_element_dict(self, attachment: Attachment) -> "ElementDict": - metadata = attachment.metadata or {} - return { - "chainlitKey": None, - "display": metadata.get("display", "side"), - "language": metadata.get("language"), - "autoPlay": metadata.get("autoPlay", None), - "playerConfig": metadata.get("playerConfig", None), - "page": metadata.get("page"), - "size": metadata.get("size"), - "type": metadata.get("type", "file"), - "forId": attachment.step_id, - "id": attachment.id or "", - "mime": attachment.mime, - "name": attachment.name or "", - "objectKey": attachment.object_key, - "url": attachment.url, - "threadId": attachment.thread_id, - } +class LiteralToChainlitConverter: + @classmethod + def steptype_to_steptype(cls, step_type: Optional[StepType]) -> TrueStepType: + if step_type in ["user_message", "assistant_message", "system_message"]: + return "undefined" + return cast(TrueStepType, step_type or "undefined") - def score_to_feedback_dict( - self, score: Optional[LiteralScore] + @classmethod + def score_to_feedbackdict( + cls, + score: Optional[LiteralScore], ) -> "Optional[FeedbackDict]": if not score: return None @@ -68,7 +47,8 @@ def score_to_feedback_dict( "comment": score.comment, } - def step_to_step_dict(self, step: LiteralStep) -> "StepDict": + @classmethod + def step_to_stepdict(cls, step: LiteralStep) -> "StepDict": metadata = step.metadata or {} input = (step.input or {}).get("content") or ( json.dumps(step.input) if step.input and step.input != {} else "" @@ -95,7 +75,7 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict": "id": step.id or "", "threadId": step.thread_id or "", "parentId": step.parent_id, - "feedback": self.score_to_feedback_dict(user_feedback), + "feedback": cls.score_to_feedbackdict(user_feedback), "start": step.start_time, "end": step.end_time, "type": step.type or "undefined", @@ -110,6 +90,113 @@ def step_to_step_dict(self, step: LiteralStep) -> "StepDict": "waitForAnswer": metadata.get("waitForAnswer", False), } + @classmethod + def attachment_to_elementdict(cls, attachment: Attachment) -> ElementDict: + metadata = attachment.metadata or {} + return { + "chainlitKey": None, + "display": metadata.get("display", "side"), + "language": metadata.get("language"), + "autoPlay": metadata.get("autoPlay", None), + "playerConfig": metadata.get("playerConfig", None), + "page": metadata.get("page"), + "size": metadata.get("size"), + "type": metadata.get("type", "file"), + "forId": attachment.step_id, + "id": attachment.id or "", + "mime": attachment.mime, + "name": attachment.name or "", + "objectKey": attachment.object_key, + "url": attachment.url, + "threadId": attachment.thread_id, + } + + @classmethod + def attachment_to_element(cls, attachment: Attachment) -> Element: + from chainlit.element import Element, File, Image, Audio, Video, Text, Pdf + + metadata = attachment.metadata or {} + element_type = metadata.get("type", "file") + + element_class = { + "file": File, + "image": Image, + "audio": Audio, + "video": Video, + "text": Text, + "pdf": Pdf, + }.get(element_type, Element) + + element = element_class( + name=attachment.name or "", + display=metadata.get("display", "side"), + language=metadata.get("language"), + size=metadata.get("size"), + url=attachment.url, + mime=attachment.mime, + thread_id=attachment.thread_id, + ) + element.id = attachment.id or "" + element.for_id = attachment.step_id + element.object_key = attachment.object_key + return element + + @classmethod + def step_to_step(cls, step: LiteralStep) -> Step: + chainlit_step = Step( + name=step.name or "", + type=cls.steptype_to_steptype(step.type), + id=step.id, + parent_id=step.parent_id, + thread_id=step.thread_id or None, + ) + chainlit_step.start = step.start_time + chainlit_step.end = step.end_time + chainlit_step.created_at = step.created_at + chainlit_step.input = step.input.get("content", "") if step.input else "" + chainlit_step.output = step.output.get("content", "") if step.output else "" + chainlit_step.is_error = bool(step.error) + chainlit_step.metadata = step.metadata or {} + chainlit_step.tags = step.tags + chainlit_step.generation = step.generation + + if step.attachments: + chainlit_step.elements = [ + cls.attachment_to_element(attachment) for attachment in step.attachments + ] + + return chainlit_step + + @classmethod + def thread_to_threaddict(cls, thread: LiteralThread) -> ThreadDict: + return { + "id": thread.id, + "createdAt": getattr(thread, "created_at", ""), + "name": thread.name, + "userId": thread.participant_id, + "userIdentifier": thread.participant_identifier, + "tags": thread.tags, + "metadata": thread.metadata, + "steps": [cls.step_to_stepdict(step) for step in thread.steps] + if thread.steps + else [], + "elements": [ + cls.attachment_to_elementdict(attachment) + for step in thread.steps + for attachment in step.attachments + ] + if thread.steps + else [], + } + + +class LiteralDataLayer(BaseDataLayer): + def __init__(self, api_key: str, server: Optional[str]): + from literalai import AsyncLiteralClient + + self.client = AsyncLiteralClient(api_key=api_key, url=server) + logger.info("Chainlit data layer initialized") + async def build_debug_url(self) -> str: try: project_id = await self.client.api.get_my_project_id() @@ -239,7 +326,7 @@ async def get_element( attachment = await self.client.api.get_attachment(id=element_id) if not attachment: return None - return self.attachment_to_element_dict(attachment) + return LiteralToChainlitConverter.attachment_to_elementdict(attachment) @queue_until_user_message() async def delete_element(self, element_id: str, thread_id: Optional[str] = None): @@ -339,13 +426,18 @@ async def list_threads( filters=literal_filters, order_by={"column": "createdAt", "direction": "DESC"}, ) + + chainlit_threads = [ + *map(LiteralToChainlitConverter.thread_to_threaddict, literal_response.data) + ] + return PaginatedResponse( pageInfo=PageInfo( - hasNextPage=literal_response.pageInfo.hasNextPage, - startCursor=literal_response.pageInfo.startCursor, - endCursor=literal_response.pageInfo.endCursor, + hasNextPage=literal_response.page_info.has_next_page, + startCursor=literal_response.page_info.start_cursor, + endCursor=literal_response.page_info.end_cursor, ), - data=literal_response.data, + data=chainlit_threads, ) async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": @@ -359,12 +451,17 @@ async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]": if thread.steps: for step in thread.steps: for attachment in step.attachments: - elements.append(self.attachment_to_element_dict(attachment)) - - if check_add_step_in_cot(step): - steps.append(self.step_to_step_dict(step)) + elements.append( + LiteralToChainlitConverter.attachment_to_elementdict(attachment) + ) + + chainlit_step = LiteralToChainlitConverter.step_to_step(step) + if check_add_step_in_cot(chainlit_step): + steps.append( + LiteralToChainlitConverter.step_to_stepdict(step) + ) # TODO: chainlit_step.to_dict() else: - steps.append(stub_step(step)) + steps.append(stub_step(chainlit_step)) return { "createdAt": thread.created_at or "", diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 3b58b38714..55df0dc0c5 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -83,12 +83,15 @@ class Element: language: Optional[str] = None # Mime type, infered based on content if not provided mime: Optional[str] = None + # Thread id + thread_id: Optional[str] = None def __post_init__(self) -> None: trace_event(f"init {self.__class__.__name__}") self.persisted = False self.updatable = False - self.thread_id = context.session.thread_id + if not self.thread_id: + self.thread_id = context.session.thread_id if not self.url and not self.path and not self.content: raise ValueError("Must provide url, path or content to instantiate element") diff --git a/backend/chainlit/step.py b/backend/chainlit/step.py index eb126817f5..0ee355bd65 100644 --- a/backend/chainlit/step.py +++ b/backend/chainlit/step.py @@ -189,12 +189,13 @@ def __init__( tags: Optional[List[str]] = None, language: Optional[str] = None, show_input: Union[bool, str] = "json", + thread_id: Optional[str] = None, ): trace_event(f"init {self.__class__.__name__} {type}") time.sleep(0.001) self._input = "" self._output = "" - self.thread_id = context.session.thread_id + self.thread_id = thread_id or context.session.thread_id self.name = name or "" self.type = type self.id = id or str(uuid.uuid4()) diff --git a/backend/tests/data/test_literalai.py b/backend/tests/data/test_literalai.py index c3b9b803c4..e7822c95ff 100644 --- a/backend/tests/data/test_literalai.py +++ b/backend/tests/data/test_literalai.py @@ -1,20 +1,31 @@ +import pytest import datetime import uuid from unittest.mock import ANY, AsyncMock, Mock, patch -import pytest +from literalai.observability.thread import ThreadDict as LiteralThreadDict +from literalai.observability.step import ( + AttachmentDict as LiteralAttachmentDict, + StepDict as LiteralStepDict, +) from httpx import HTTPStatusError, RequestError -from literalai import AsyncLiteralClient +from literalai import AsyncLiteralClient, PaginatedResponse, PageInfo, Thread, UserDict from literalai import Step as LiteralStep from literalai import Thread as LiteralThread from literalai import User as LiteralUser +from literalai import Score as LiteralScore +from literalai import Attachment from literalai.api import AsyncLiteralAPI -from chainlit.data.literalai import LiteralDataLayer -from chainlit.element import Text -from chainlit.step import StepDict -from chainlit.types import Feedback, Pagination, ThreadFilter +from chainlit.step import Step, StepDict +from chainlit.element import File, Image, Audio, Video, Text, Pdf +from chainlit.data.literalai import LiteralDataLayer, LiteralToChainlitConverter +from chainlit.types import ( + Feedback, + Pagination, + ThreadFilter, +) from chainlit.user import PersistedUser, User @@ -465,23 +476,23 @@ async def test_list_threads( test_filters: ThreadFilter, test_pagination: Pagination, ): - mock_response = Mock() - mock_response.pageInfo = Mock( - hasNextPage=True, startCursor="start_cursor", endCursor="end_cursor" + response: PaginatedResponse[Thread] = PaginatedResponse( + page_info=PageInfo( + has_next_page=True, start_cursor="start_cursor", end_cursor="end_cursor" + ), + data=[ + Thread( + id="thread1", + name="Thread 1", + ), + Thread( + id="thread2", + name="Thread 2", + ), + ], ) - mock_response.data = [ - { - "id": "thread1", - "name": "Thread 1", - "createdAt": "2023-01-01T00:00:00Z", - }, - { - "id": "thread2", - "name": "Thread 2", - "createdAt": "2023-01-02T00:00:00Z", - }, - ] - mock_literal_client.api.list_threads.return_value = mock_response + + mock_literal_client.api.list_threads.return_value = response result = await literal_data_layer.list_threads(test_pagination, test_filters) @@ -760,10 +771,22 @@ async def test_update_step( ) -async def test_score_to_feedback_dict(literal_data_layer: LiteralDataLayer): - from literalai import Score as LiteralScore +def test_steptype_to_steptype(): + assert ( + LiteralToChainlitConverter.steptype_to_steptype("user_message") == "undefined" + ) + assert ( + LiteralToChainlitConverter.steptype_to_steptype("assistant_message") + == "undefined" + ) + assert ( + LiteralToChainlitConverter.steptype_to_steptype("system_message") == "undefined" + ) + assert LiteralToChainlitConverter.steptype_to_steptype("tool") == "tool" + assert LiteralToChainlitConverter.steptype_to_steptype(None) == "undefined" + - # Test with a valid score +def test_score_to_feedbackdict(): score = LiteralScore( id="test_score_id", step_id="test_step_id", @@ -774,7 +797,7 @@ async def test_score_to_feedback_dict(literal_data_layer: LiteralDataLayer): dataset_experiment_item_id=None, tags=None, ) - feedback_dict = literal_data_layer.score_to_feedback_dict(score) + feedback_dict = LiteralToChainlitConverter.score_to_feedbackdict(score) assert feedback_dict == { "id": "test_score_id", "forId": "test_step_id", @@ -782,19 +805,254 @@ async def test_score_to_feedback_dict(literal_data_layer: LiteralDataLayer): "comment": "Great job!", } - # Test with None score - assert literal_data_layer.score_to_feedback_dict(None) is None + assert LiteralToChainlitConverter.score_to_feedbackdict(None) is None - # Test with score value 0 score.value = 0 - feedback_dict = literal_data_layer.score_to_feedback_dict(score) + feedback_dict = LiteralToChainlitConverter.score_to_feedbackdict(score) assert feedback_dict is not None assert feedback_dict["value"] == 0 - # Test with missing id or step_id score.id = None score.step_id = None - feedback_dict = literal_data_layer.score_to_feedback_dict(score) + feedback_dict = LiteralToChainlitConverter.score_to_feedbackdict(score) assert feedback_dict is not None assert feedback_dict["id"] == "" assert feedback_dict["forId"] == "" + + +def test_step_to_stepdict(): + literal_step = LiteralStep.from_dict( + { + "id": "test_step_id", + "threadId": "test_thread_id", + "type": "user_message", + "name": "Test Step", + "input": {"content": "test input"}, + "output": {"content": "test output"}, + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:01Z", + "createdAt": "2023-01-01T00:00:00Z", + "metadata": {"showInput": True, "language": "en"}, + "error": None, + "scores": [ + { + "id": "test_score_id", + "stepId": "test_step_id", + "value": 1, + "comment": "Great job!", + "name": "user-feedback", + "type": "HUMAN", + } + ], + } + ) + + step_dict = LiteralToChainlitConverter.step_to_stepdict(literal_step) + + assert step_dict.get("id") == "test_step_id" + assert step_dict.get("threadId") == "test_thread_id" + assert step_dict.get("type") == "user_message" + assert step_dict.get("name") == "Test Step" + assert step_dict.get("input") == "test input" + assert step_dict.get("output") == "test output" + assert step_dict.get("start") == "2023-01-01T00:00:00Z" + assert step_dict.get("end") == "2023-01-01T00:00:01Z" + assert step_dict.get("createdAt") == "2023-01-01T00:00:00Z" + assert step_dict.get("showInput") == True + assert step_dict.get("language") == "en" + assert step_dict.get("isError") == False + assert step_dict.get("feedback") == { + "id": "test_score_id", + "forId": "test_step_id", + "value": 1, + "comment": "Great job!", + } + + +def test_attachment_to_elementdict(): + attachment = Attachment( + id="test_attachment_id", + step_id="test_step_id", + thread_id="test_thread_id", + name="test.txt", + mime="text/plain", + url="https://example.com/test.txt", + object_key="test_object_key", + metadata={ + "display": "side", + "language": "python", + "type": "file", + "size": "large", + }, + ) + + element_dict = LiteralToChainlitConverter.attachment_to_elementdict(attachment) + + assert element_dict["id"] == "test_attachment_id" + assert element_dict["forId"] == "test_step_id" + assert element_dict["threadId"] == "test_thread_id" + assert element_dict["name"] == "test.txt" + assert element_dict["mime"] == "text/plain" + assert element_dict["url"] == "https://example.com/test.txt" + assert element_dict["objectKey"] == "test_object_key" + assert element_dict["display"] == "side" + assert element_dict["language"] == "python" + assert element_dict["type"] == "file" + assert element_dict["size"] == "large" + + +def test_attachment_to_element(): + attachment = Attachment( + id="test_attachment_id", + step_id="test_step_id", + thread_id="test_thread_id", + name="test.txt", + mime="text/plain", + url="https://example.com/test.txt", + object_key="test_object_key", + metadata={ + "display": "side", + "language": "python", + "type": "text", + "size": "small", + }, + ) + + element = LiteralToChainlitConverter.attachment_to_element(attachment) + + assert isinstance(element, Text) + assert element.id == "test_attachment_id" + assert element.for_id == "test_step_id" + assert element.thread_id == "test_thread_id" + assert element.name == "test.txt" + assert element.mime == "text/plain" + assert element.url == "https://example.com/test.txt" + assert element.object_key == "test_object_key" + assert element.display == "side" + assert element.language == "python" + assert element.size == "small" + + # Test other element types + for element_type in ["file", "image", "audio", "video", "pdf"]: + attachment.metadata = {"type": element_type, "size": "small"} + + element = LiteralToChainlitConverter.attachment_to_element(attachment) + assert isinstance( + element, + { + "file": File, + "image": Image, + "audio": Audio, + "video": Video, + "text": Text, + "pdf": Pdf, + }[element_type], + ) + + +def test_step_to_step(): + literal_step = LiteralStep.from_dict( + { + "id": "test_step_id", + "threadId": "test_thread_id", + "type": "user_message", + "name": "Test Step", + "input": {"content": "test input"}, + "output": {"content": "test output"}, + "startTime": "2023-01-01T00:00:00Z", + "endTime": "2023-01-01T00:00:01Z", + "createdAt": "2023-01-01T00:00:00Z", + "metadata": {"showInput": True, "language": "en"}, + "error": None, + "attachments": [ + { + "id": "test_attachment_id", + "stepId": "test_step_id", + "threadId": "test_thread_id", + "name": "test.txt", + "mime": "text/plain", + "url": "https://example.com/test.txt", + "objectKey": "test_object_key", + "metadata": { + "display": "side", + "language": "python", + "type": "text", + }, + } + ], + } + ) + + chainlit_step = LiteralToChainlitConverter.step_to_step(literal_step) + + assert isinstance(chainlit_step, Step) + assert chainlit_step.id == "test_step_id" + assert chainlit_step.thread_id == "test_thread_id" + assert chainlit_step.type == "undefined" + assert chainlit_step.name == "Test Step" + assert chainlit_step.input == "test input" + assert chainlit_step.output == "test output" + assert chainlit_step.start == "2023-01-01T00:00:00Z" + assert chainlit_step.end == "2023-01-01T00:00:01Z" + assert chainlit_step.created_at == "2023-01-01T00:00:00Z" + assert chainlit_step.metadata == {"showInput": True, "language": "en"} + assert not chainlit_step.is_error + assert chainlit_step.elements is not None + assert len(chainlit_step.elements) == 1 + assert isinstance(chainlit_step.elements[0], Text) + + +def test_thread_to_threaddict(): + attachment_dict = LiteralAttachmentDict( + id="test_attachment_id", + stepId="test_step_id", + threadId="test_thread_id", + name="test.txt", + mime="text/plain", + url="https://example.com/test.txt", + objectKey="test_object_key", + metadata={ + "display": "side", + "language": "python", + "type": "text", + }, + ) + step_dict = LiteralStepDict( + id="test_step_id", + threadId="test_thread_id", + type="user_message", + name="Test Step", + input={"content": "test input"}, + output={"content": "test output"}, + startTime="2023-01-01T00:00:00Z", + endTime="2023-01-01T00:00:01Z", + createdAt="2023-01-01T00:00:00Z", + metadata={"showInput": True, "language": "en"}, + error=None, + attachments=[attachment_dict], + ) + literal_thread = LiteralThread.from_dict( + LiteralThreadDict( + id="test_thread_id", + name="Test Thread", + createdAt="2023-01-01T00:00:00Z", + participant=UserDict(id="test_user_id", identifier="test_user_identifier_"), + tags=["tag1", "tag2"], + metadata={"key": "value"}, + steps=[step_dict], + ) + ) + + thread_dict = LiteralToChainlitConverter.thread_to_threaddict(literal_thread) + + assert thread_dict["id"] == "test_thread_id" + assert thread_dict["name"] == "Test Thread" + assert thread_dict["createdAt"] == "2023-01-01T00:00:00Z" + assert thread_dict["userId"] == "test_user_id" + assert thread_dict["userIdentifier"] == "test_user_identifier_" + assert thread_dict["tags"] == ["tag1", "tag2"] + assert thread_dict["metadata"] == {"key": "value"} + assert thread_dict["steps"] is not None + assert len(thread_dict["steps"]) == 1 + assert thread_dict["elements"] is not None + assert len(thread_dict["elements"]) == 1