diff --git a/backend/chainlit/__init__.py b/backend/chainlit/__init__.py index 279762d3eb..37ae77594a 100644 --- a/backend/chainlit/__init__.py +++ b/backend/chainlit/__init__.py @@ -69,7 +69,9 @@ on_message, on_settings_update, on_stop, + on_window_message, password_auth_callback, + send_window_message, set_chat_profiles, set_starters, ) @@ -151,6 +153,8 @@ def acall(self): "CompletionGeneration", "GenerationMessage", "on_logout", + "on_window_message", + "send_window_message", "on_chat_start", "on_chat_end", "on_chat_resume", diff --git a/backend/chainlit/callbacks.py b/backend/chainlit/callbacks.py index f03625da64..106904e3d1 100644 --- a/backend/chainlit/callbacks.py +++ b/backend/chainlit/callbacks.py @@ -6,6 +6,7 @@ from chainlit.action import Action from chainlit.config import config +from chainlit.context import context from chainlit.data.base import BaseDataLayer from chainlit.message import Message from chainlit.oauth_providers import get_configured_oauth_providers @@ -125,6 +126,33 @@ async def with_parent_id(message: Message): return func +@trace +async def send_window_message(data: Any): + """ + Send custom data to the host window via a window.postMessage event. + + Args: + data (Any): The data to send with the event. + """ + await context.emitter.send_window_message(data) + + +@trace +def on_window_message(func: Callable[[str], Any]) -> Callable: + """ + Hook to react to javascript postMessage events coming from the UI. + + Args: + func (Callable[[str], Any]): The function to be called when a window message is received. + Takes the message content as a string parameter. + + Returns: + Callable[[str], Any]: The decorated on_window_message function. + """ + config.code.on_window_message = wrap_user_function(func) + return func + + @trace def on_chat_start(func: Callable) -> Callable: """ diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index 71fa8fca1f..d59da1bca1 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -287,6 +287,7 @@ class CodeSettings: on_chat_end: Optional[Callable[[], Any]] = None on_chat_resume: Optional[Callable[["ThreadDict"], Any]] = None on_message: Optional[Callable[["Message"], Any]] = None + on_window_message: Optional[Callable[[str], Any]] = None on_audio_start: Optional[Callable[[], Any]] = None on_audio_chunk: Optional[Callable[["InputAudioChunk"], Any]] = None on_audio_end: Optional[Callable[[], Any]] = None diff --git a/backend/chainlit/emitter.py b/backend/chainlit/emitter.py index df8a78e9f4..5cc6a905b6 100644 --- a/backend/chainlit/emitter.py +++ b/backend/chainlit/emitter.py @@ -2,6 +2,9 @@ import uuid from typing import Any, Dict, List, Literal, Optional, Union, cast +from literalai.helper import utc_now +from socketio.exceptions import TimeoutError + from chainlit.chat_context import chat_context from chainlit.config import config from chainlit.data import get_data_layer @@ -16,12 +19,10 @@ FileDict, FileReference, MessagePayload, + OutputAudioChunk, ThreadDict, - OutputAudioChunk ) from chainlit.user import PersistedUser -from literalai.helper import utc_now -from socketio.exceptions import TimeoutError class BaseChainlitEmitter: @@ -52,15 +53,15 @@ async def resume_thread(self, thread_dict: ThreadDict): async def send_element(self, element_dict: ElementDict): """Stub method to send an element to the UI.""" pass - + async def update_audio_connection(self, state: Literal["on", "off"]): """Audio connection signaling.""" pass - + async def send_audio_chunk(self, chunk: OutputAudioChunk): """Stub method to send an audio chunk to the UI.""" pass - + async def send_audio_interrupt(self): """Stub method to interrupt the current audio response.""" pass @@ -133,6 +134,10 @@ async def send_action_response( """Send an action response to the UI.""" pass + async def send_window_message(self, data: Any): + """Stub method to send custom data to the host window.""" + pass + class ChainlitEmitter(BaseChainlitEmitter): """ @@ -177,7 +182,7 @@ async def update_audio_connection(self, state: Literal["on", "off"]): async def send_audio_chunk(self, chunk: OutputAudioChunk): """Send an audio chunk to the UI.""" await self.emit("audio_chunk", chunk) - + async def send_audio_interrupt(self): """Method to interrupt the current audio response.""" await self.emit("audio_interrupt", {}) @@ -392,3 +397,7 @@ def send_action_response( return self.emit( "action_response", {"id": id, "status": status, "response": response} ) + + def send_window_message(self, data: Any): + """Send custom data to the host window.""" + return self.emit("window_message", data) diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py index 4cfc42fa9a..064c4c2189 100644 --- a/backend/chainlit/socket.py +++ b/backend/chainlit/socket.py @@ -17,11 +17,7 @@ from chainlit.server import sio from chainlit.session import WebsocketSession from chainlit.telemetry import trace_event -from chainlit.types import ( - InputAudioChunk, - InputAudioChunkPayload, - MessagePayload, -) +from chainlit.types import InputAudioChunk, InputAudioChunkPayload, MessagePayload from chainlit.user_session import user_sessions @@ -313,6 +309,23 @@ async def message(sid, payload: MessagePayload): session.current_task = task +@sio.on("window_message") +async def window_message(sid, data): + """Handle a message send by the host window.""" + session = WebsocketSession.require(sid) + context = init_ws_context(session) + + await context.emitter.task_start() + + if config.code.on_window_message: + try: + await config.code.on_window_message(data) + except asyncio.CancelledError: + pass + finally: + await context.emitter.task_end() + + @sio.on("audio_start") async def audio_start(sid): """Handle audio init.""" @@ -320,10 +333,10 @@ async def audio_start(sid): context = init_ws_context(session) if config.code.on_audio_start: - connected = bool(await config.code.on_audio_start()) - connection_state = "on" if connected else "off" - await context.emitter.update_audio_connection(connection_state) - + connected = bool(await config.code.on_audio_start()) + connection_state = "on" if connected else "off" + await context.emitter.update_audio_connection(connection_state) + @sio.on("audio_chunk") async def audio_chunk(sid, payload: InputAudioChunkPayload): @@ -350,7 +363,7 @@ async def audio_end(sid): if config.code.on_audio_end: await config.code.on_audio_end() - + except asyncio.CancelledError: pass except Exception as e: diff --git a/cypress/e2e/window_message/main.py b/cypress/e2e/window_message/main.py new file mode 100644 index 0000000000..8843e7cb09 --- /dev/null +++ b/cypress/e2e/window_message/main.py @@ -0,0 +1,12 @@ +import chainlit as cl + + +@cl.on_window_message +async def window_message(message: str): + if message.startswith("Client: "): + await cl.send_window_message("Server: World") + + +@cl.on_message +async def message(message: str): + await cl.Message(content="ok").send() diff --git a/cypress/e2e/window_message/public/iframe.html b/cypress/e2e/window_message/public/iframe.html new file mode 100644 index 0000000000..9272f2c1fd --- /dev/null +++ b/cypress/e2e/window_message/public/iframe.html @@ -0,0 +1,18 @@ + + + + Chainlit iframe + + +

Chainlit iframe

+ +
No message received
+ + + diff --git a/cypress/e2e/window_message/spec.cy.ts b/cypress/e2e/window_message/spec.cy.ts new file mode 100644 index 0000000000..6a1643d4cd --- /dev/null +++ b/cypress/e2e/window_message/spec.cy.ts @@ -0,0 +1,28 @@ +import { runTestServer } from '../../support/testUtils'; + +const getIframeWindow = () => { + return cy + .get('iframe[data-cy="the-frame"]') + .its('0.contentWindow') + .should('exist'); +}; + +describe('Window Message', () => { + before(() => { + runTestServer(); + }); + + it('should be able to send and receive window messages', () => { + cy.visit('/public/iframe.html'); + + cy.get('div#message').should('contain', 'No message received'); + + getIframeWindow().then((win) => { + cy.wait(1000).then(() => { + win.postMessage('Client: Hello', '*'); + }); + }); + + cy.get('div#message').should('contain', 'Server: World'); + }); +}); diff --git a/frontend/src/AppWrapper.tsx b/frontend/src/AppWrapper.tsx index b32f10abea..d7d249cbaa 100644 --- a/frontend/src/AppWrapper.tsx +++ b/frontend/src/AppWrapper.tsx @@ -3,12 +3,13 @@ import { useEffect } from 'react'; import { useTranslation } from 'react-i18next'; import getRouterBasename from 'utils/router'; -import { useApi, useAuth, useConfig } from '@chainlit/react-client'; +import { useApi, useAuth, useChatInteract, useConfig } from '@chainlit/react-client'; export default function AppWrapper() { const { isAuthenticated, isReady } = useAuth(); const { language: languageInUse } = useConfig(); const { i18n } = useTranslation(); + const { windowMessage } = useChatInteract(); function handleChangeLanguage(languageBundle: any): void { i18n.addResourceBundle(languageInUse, 'translation', languageBundle); @@ -33,6 +34,14 @@ export default function AppWrapper() { handleChangeLanguage(translations.translation); }, [translations]); + useEffect(() => { + const handleWindowMessage = (event: MessageEvent) => { + windowMessage(event.data); + } + window.addEventListener('message', handleWindowMessage); + return () => window.removeEventListener('message', handleWindowMessage); + }, [windowMessage]); + if (!isReady) { return null; } diff --git a/libs/react-client/src/useChatInteract.ts b/libs/react-client/src/useChatInteract.ts index b598093034..9eefe3b472 100644 --- a/libs/react-client/src/useChatInteract.ts +++ b/libs/react-client/src/useChatInteract.ts @@ -90,6 +90,13 @@ const useChatInteract = () => { [session?.socket] ); + const windowMessage = useCallback( + (data: any) => { + session?.socket.emit('window_message', data); + }, + [session?.socket] + ); + const startAudioStream = useCallback(() => { session?.socket.emit('audio_start'); }, [session?.socket]); @@ -186,6 +193,7 @@ const useChatInteract = () => { replyMessage, sendMessage, editMessage, + windowMessage, startAudioStream, sendAudioChunk, endAudioStream, diff --git a/libs/react-client/src/useChatSession.ts b/libs/react-client/src/useChatSession.ts index 9020847d2a..384a8f31fd 100644 --- a/libs/react-client/src/useChatSession.ts +++ b/libs/react-client/src/useChatSession.ts @@ -359,6 +359,12 @@ const useChatSession = () => { socket.on('token_usage', (count: number) => { setTokenCount((old) => old + count); }); + + socket.on('window_message', (data: any) => { + if (window.parent) { + window.parent.postMessage(data, '*'); + } + }); }, [setSession, sessionId, chatProfile] );