diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 279762d3eb..37ae77594a 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -69,7 +69,9 @@ on_message, on_settings_update, on_stop, + on_window_message, password_auth_callback, + send_window_message, set_chat_profiles, set_starters, ) @@ -151,6 +153,8 @@ def acall(self): "CompletionGeneration", "GenerationMessage", "on_logout", + "on_window_message", + "send_window_message", "on_chat_start", "on_chat_end", "on_chat_resume", diff --git a/backend/chainlit/callbacks.py b/backend/chainlit/callbacks.py index f03625da64..106904e3d1 100644 --- a/backend/chainlit/callbacks.py +++ b/backend/chainlit/callbacks.py @@ -6,6 +6,7 @@ from chainlit.action import Action from chainlit.config import config +from chainlit.context import context from chainlit.data.base import BaseDataLayer from chainlit.message import Message from chainlit.oauth_providers import get_configured_oauth_providers @@ -125,6 +126,33 @@ async def with_parent_id(message: Message): return func +@trace +async def send_window_message(data: Any): + """ + Send custom data to the host window via a window.postMessage event. + + Args: + data (Any): The data to send with the event. + """ + await context.emitter.send_window_message(data) + + +@trace +def on_window_message(func: Callable[[str], Any]) -> Callable: + """ + Hook to react to javascript postMessage events coming from the UI. + + Args: + func (Callable[[str], Any]): The function to be called when a window message is received. + Takes the message content as a string parameter. + + Returns: + Callable[[str], Any]: The decorated on_window_message function. + """ + config.code.on_window_message = wrap_user_function(func) + return func + + @trace def on_chat_start(func: Callable) -> Callable: """ diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 71fa8fca1f..d59da1bca1 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -287,6 +287,7 @@ class CodeSettings: on_chat_end: Optional[Callable[[], Any]] = None on_chat_resume: Optional[Callable[["ThreadDict"], Any]] = None on_message: Optional[Callable[["Message"], Any]] = None + on_window_message: Optional[Callable[[str], Any]] = None on_audio_start: Optional[Callable[[], Any]] = None on_audio_chunk: Optional[Callable[["InputAudioChunk"], Any]] = None on_audio_end: Optional[Callable[[], Any]] = None diff --git a/backend/chainlit/emitter.py b/backend/chainlit/emitter.py index df8a78e9f4..5cc6a905b6 100644 --- a/backend/chainlit/emitter.py +++ b/backend/chainlit/emitter.py @@ -2,6 +2,9 @@ import uuid from typing import Any, Dict, List, Literal, Optional, Union, cast +from literalai.helper import utc_now +from socketio.exceptions import TimeoutError + from chainlit.chat_context import chat_context from chainlit.config import config from chainlit.data import get_data_layer @@ -16,12 +19,10 @@ FileDict, FileReference, MessagePayload, + OutputAudioChunk, ThreadDict, - OutputAudioChunk ) from chainlit.user import PersistedUser -from literalai.helper import utc_now -from socketio.exceptions import TimeoutError class BaseChainlitEmitter: @@ -52,15 +53,15 @@ async def resume_thread(self, thread_dict: ThreadDict): async def send_element(self, element_dict: ElementDict): """Stub method to send an element to the UI.""" pass - + async def update_audio_connection(self, state: Literal["on", "off"]): """Audio connection signaling.""" pass - + async def send_audio_chunk(self, chunk: OutputAudioChunk): """Stub method to send an audio chunk to the UI.""" pass - + async def send_audio_interrupt(self): """Stub method to interrupt the current audio response.""" pass @@ -133,6 +134,10 @@ async def send_action_response( """Send an action response to the UI.""" pass + async def send_window_message(self, data: Any): + """Stub method to send custom data to the host window.""" + pass + class ChainlitEmitter(BaseChainlitEmitter): """ @@ -177,7 +182,7 @@ async def update_audio_connection(self, state: Literal["on", "off"]): async def send_audio_chunk(self, chunk: OutputAudioChunk): """Send an audio chunk to the UI.""" await self.emit("audio_chunk", chunk) - + async def send_audio_interrupt(self): """Method to interrupt the current audio response.""" await self.emit("audio_interrupt", {}) @@ -392,3 +397,7 @@ def send_action_response( return self.emit( "action_response", {"id": id, "status": status, "response": response} ) + + def send_window_message(self, data: Any): + """Send custom data to the host window.""" + return self.emit("window_message", data) diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 4cfc42fa9a..064c4c2189 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -17,11 +17,7 @@ from chainlit.server import sio from chainlit.session import WebsocketSession from chainlit.telemetry import trace_event -from chainlit.types import ( - InputAudioChunk, - InputAudioChunkPayload, - MessagePayload, -) +from chainlit.types import InputAudioChunk, InputAudioChunkPayload, MessagePayload from chainlit.user_session import user_sessions @@ -313,6 +309,23 @@ async def message(sid, payload: MessagePayload): session.current_task = task +@sio.on("window_message") +async def window_message(sid, data): + """Handle a message send by the host window.""" + session = WebsocketSession.require(sid) + context = init_ws_context(session) + + await context.emitter.task_start() + + if config.code.on_window_message: + try: + await config.code.on_window_message(data) + except asyncio.CancelledError: + pass + finally: + await context.emitter.task_end() + + @sio.on("audio_start") async def audio_start(sid): """Handle audio init.""" @@ -320,10 +333,10 @@ async def audio_start(sid): context = init_ws_context(session) if config.code.on_audio_start: - connected = bool(await config.code.on_audio_start()) - connection_state = "on" if connected else "off" - await context.emitter.update_audio_connection(connection_state) - + connected = bool(await config.code.on_audio_start()) + connection_state = "on" if connected else "off" + await context.emitter.update_audio_connection(connection_state) + @sio.on("audio_chunk") async def audio_chunk(sid, payload: InputAudioChunkPayload): @@ -350,7 +363,7 @@ async def audio_end(sid): if config.code.on_audio_end: await config.code.on_audio_end() - + except asyncio.CancelledError: pass except Exception as e: diff --git a/cypress/e2e/window_message/main.py b/cypress/e2e/window_message/main.py new file mode 100644 index 0000000000..8843e7cb09 --- /dev/null +++ b/cypress/e2e/window_message/main.py @@ -0,0 +1,12 @@ +import chainlit as cl + + +@cl.on_window_message +async def window_message(message: str): + if message.startswith("Client: "): + await cl.send_window_message("Server: World") + + +@cl.on_message +async def message(message: str): + await cl.Message(content="ok").send() diff --git a/cypress/e2e/window_message/public/iframe.html b/cypress/e2e/window_message/public/iframe.html new file mode 100644 index 0000000000..9272f2c1fd --- /dev/null +++ b/cypress/e2e/window_message/public/iframe.html @@ -0,0 +1,18 @@ + + +
+