From 38439cfe237a1b64c1cd32b5dc949fc456464d00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Sirieix?= Date: Mon, 9 Oct 2023 18:39:01 +0200 Subject: [PATCH] Reduce perceived latency for users (#463) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat/offload persistence * feat/offload persistence element * fix/fix review --------- Co-authored-by: Clément Sirieix --- backend/chainlit/element.py | 27 ++++++++++++------ backend/chainlit/message.py | 56 ++++++++++++++++++++++++++----------- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 59b73180ba..12bb89901b 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -10,8 +10,10 @@ from chainlit.client.cloud import ChainlitCloudClient from chainlit.context import context from chainlit.data import chainlit_client +from chainlit.logger import logger from chainlit.telemetry import trace_event from pydantic.dataclasses import Field, dataclass +from syncer import asyncio mime_types = { "text": "text/plain", @@ -98,18 +100,25 @@ async def persist(self, client: ChainlitCloudClient) -> Optional[ElementDict]: ) self.url = upload_res["url"] self.object_key = upload_res["object_key"] + element_dict = await self.with_conversation_id() + + asyncio.create_task(self._persist(element_dict)) - if not self.persisted: - element_dict = await client.create_element( - await self.with_conversation_id() - ) - self.persisted = True - else: - element_dict = await client.update_element( - await self.with_conversation_id() - ) return element_dict + async def _persist(self, element: ElementDict): + if not chainlit_client: + return + + try: + if self.persisted: + await chainlit_client.update_element(element) + else: + await chainlit_client.create_element(element) + self.persisted = True + except Exception as e: + logger.error(f"Failed to persist element: {str(e)}") + async def before_emit(self, element: Dict) -> Dict: return element diff --git a/backend/chainlit/message.py b/backend/chainlit/message.py index 9b40fb0029..5b1315b089 100644 --- a/backend/chainlit/message.py +++ b/backend/chainlit/message.py @@ -14,6 +14,7 @@ from chainlit.prompt import Prompt from chainlit.telemetry import trace_event from chainlit.types import AskFileResponse, AskFileSpec, AskResponse, AskSpec +from syncer import asyncio class MessageBase(ABC): @@ -43,23 +44,27 @@ async def with_conversation_id(self): async def _create(self): msg_dict = await self.with_conversation_id() - if chainlit_client and not self.persisted: - try: - persisted_id = await chainlit_client.create_message(msg_dict) - if persisted_id: - msg_dict["id"] = persisted_id - self.id = persisted_id - self.persisted = True - except Exception as e: - if self.fail_on_persist_error: - raise e - logger.error(f"Failed to persist message: {str(e)}") + asyncio.create_task(self._persist_create(msg_dict)) if not config.features.prompt_playground: msg_dict.pop("prompt", None) - return msg_dict + async def _persist_create(self, message: MessageDict): + if not chainlit_client or self.persisted: + return + + try: + persisted_id = await chainlit_client.create_message(message) + + if persisted_id: + self.id = persisted_id + self.persisted = True + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message creation: {str(e)}") + async def update( self, ): @@ -69,14 +74,22 @@ async def update( trace_event("update_message") msg_dict = self.to_dict() - - if chainlit_client and self.id: - await chainlit_client.update_message(self.id, msg_dict) - + asyncio.create_task(self._persist_update(msg_dict)) await context.emitter.update_message(msg_dict) return True + async def _persist_update(self, message: MessageDict): + if not chainlit_client or not self.id: + return + + try: + await chainlit_client.update_message(self.id, message) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message update: {str(e)}") + async def remove(self): """ Remove a message already sent to the UI. @@ -91,6 +104,17 @@ async def remove(self): return True + async def _persist_remove(self): + if not chainlit_client or not self.id: + return + + try: + await chainlit_client.delete_message(self.id) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist message deletion: {str(e)}") + async def send(self): if self.content is None: self.content = ""