diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py
index b90f162f07..18ee6be8db 100644
--- a/backend/chainlit/config.py
+++ b/backend/chainlit/config.py
@@ -311,6 +311,8 @@ class CodeSettings:
@dataclass()
class ProjectSettings(DataClassJsonMixin):
allow_origins: List[str] = Field(default_factory=lambda: ["*"])
+ # Socket.io client transports option
+ transports: Optional[List[str]] = None
enable_telemetry: bool = True
# List of environment variables to be provided by each user to use the app. If empty, no environment variables will be asked to the user.
user_env: Optional[List[str]] = None
diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py
index 5118f544a7..7aeabe5329 100644
--- a/backend/chainlit/server.py
+++ b/backend/chainlit/server.py
@@ -301,7 +301,10 @@ def get_html_template():
"""
- js = f""""""
+ js = f""""""
css = None
if config.ui.custom_css:
diff --git a/backend/chainlit/socket.py b/backend/chainlit/socket.py
index d79c76c16e..5053262e2f 100644
--- a/backend/chainlit/socket.py
+++ b/backend/chainlit/socket.py
@@ -1,7 +1,6 @@
import asyncio
import json
import time
-import uuid
from typing import Any, Dict, Literal
from urllib.parse import unquote
@@ -77,24 +76,8 @@ def load_user_env(user_env):
return user_env
-def build_anon_user_identifier(environ):
- scope = environ.get("asgi.scope", {})
- client_ip, _ = scope.get("client")
- ip = environ.get("HTTP_X_FORWARDED_FOR", client_ip)
-
- try:
- headers = scope.get("headers", {})
- user_agent = next(
- (v.decode("utf-8") for k, v in headers if k.decode("utf-8") == "user-agent")
- )
- return str(uuid.uuid5(uuid.NAMESPACE_DNS, user_agent + ip))
-
- except StopIteration:
- return str(uuid.uuid5(uuid.NAMESPACE_DNS, ip))
-
-
@sio.on("connect")
-async def connect(sid, environ):
+async def connect(sid, environ, auth):
if (
not config.code.on_chat_start
and not config.code.on_message
@@ -110,8 +93,8 @@ async def connect(sid, environ):
try:
# Check if the authentication is required
if login_required:
- authorization_header = environ.get("HTTP_AUTHORIZATION")
- token = authorization_header.split(" ")[1] if authorization_header else None
+ token = auth.get("token")
+ token = token.split(" ")[1] if token else None
user = await get_current_user(token=token)
except Exception:
logger.info("Authentication failed")
@@ -125,16 +108,16 @@ def emit_fn(event, data):
def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
return sio.call(event, data, timeout=timeout, to=sid)
- session_id = environ.get("HTTP_X_CHAINLIT_SESSION_ID")
+ session_id = auth.get("sessionId")
if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
return True
- user_env_string = environ.get("HTTP_USER_ENV")
+ user_env_string = auth.get("userEnv")
user_env = load_user_env(user_env_string)
- client_type = environ.get("HTTP_X_CHAINLIT_CLIENT_TYPE")
+ client_type = auth.get("clientType")
http_referer = environ.get("HTTP_REFERER")
- url_encoded_chat_profile = environ.get("HTTP_X_CHAINLIT_CHAT_PROFILE")
+ url_encoded_chat_profile = auth.get("chatProfile")
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
)
@@ -149,7 +132,7 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
user=user,
token=token,
chat_profile=chat_profile,
- thread_id=environ.get("HTTP_X_CHAINLIT_THREAD_ID"),
+ thread_id=auth.get("threadId"),
languages=environ.get("HTTP_ACCEPT_LANGUAGE"),
http_referer=http_referer,
)
@@ -162,13 +145,13 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
async def connection_successful(sid):
context = init_ws_context(sid)
- if context.session.restored:
- return
-
await context.emitter.task_end()
await context.emitter.clear("clear_ask")
await context.emitter.clear("clear_call_fn")
+ if context.session.restored:
+ return
+
if context.session.thread_id_to_resume and config.code.on_chat_resume:
thread = await resume_thread(context.session)
if thread:
@@ -312,17 +295,13 @@ async def message(sid, payload: MessagePayload):
async def window_message(sid, data):
"""Handle a message send by the host window."""
session = WebsocketSession.require(sid)
- context = init_ws_context(session)
-
- await context.emitter.task_start()
+ init_ws_context(session)
if config.code.on_window_message:
try:
await config.code.on_window_message(data)
except asyncio.CancelledError:
pass
- finally:
- await context.emitter.task_end()
@sio.on("audio_start")
diff --git a/cypress/e2e/copilot/.chainlit/config.toml b/cypress/e2e/copilot/.chainlit/config.toml
index e2a93af08f..9c42755715 100644
--- a/cypress/e2e/copilot/.chainlit/config.toml
+++ b/cypress/e2e/copilot/.chainlit/config.toml
@@ -13,7 +13,7 @@ session_timeout = 3600
cache = false
# Authorized origins
-allow_origins = ["*"]
+allow_origins = ["http://127.0.0.1:8000"]
# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
# follow_symlink = false
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index cc80e03ac9..9238ca2519 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -42,6 +42,7 @@ declare global {
light?: ThemOverride;
dark?: ThemOverride;
};
+ transports?: string[]
}
}
@@ -99,6 +100,7 @@ function App() {
return;
} else {
connect({
+ transports: window.transports,
userEnv,
accessToken
});
diff --git a/libs/copilot/src/chat/index.tsx b/libs/copilot/src/chat/index.tsx
index 5f0a0779e7..3cc4bd3289 100644
--- a/libs/copilot/src/chat/index.tsx
+++ b/libs/copilot/src/chat/index.tsx
@@ -12,6 +12,7 @@ export default function ChatWrapper() {
useEffect(() => {
if (session?.socket?.connected) return;
connect({
+ transports: window.transports,
userEnv: {},
accessToken: `Bearer ${accessToken}`
});
diff --git a/libs/react-client/src/useChatSession.ts b/libs/react-client/src/useChatSession.ts
index 441e66d665..b1079179f0 100644
--- a/libs/react-client/src/useChatSession.ts
+++ b/libs/react-client/src/useChatSession.ts
@@ -78,16 +78,18 @@ const useChatSession = () => {
// Use currentThreadId as thread id in websocket header
useEffect(() => {
if (session?.socket) {
- session.socket.io.opts.extraHeaders!['X-Chainlit-Thread-Id'] =
+ session.socket.auth["threadId"] =
currentThreadId || '';
}
}, [currentThreadId]);
const _connect = useCallback(
({
+ transports,
userEnv,
accessToken
}: {
+ transports?: string[]
userEnv: Record;
accessToken?: string;
}) => {
@@ -100,16 +102,17 @@ const useChatSession = () => {
const socket = io(uri, {
path,
- extraHeaders: {
- Authorization: accessToken || '',
- 'X-Chainlit-Client-Type': client.type,
- 'X-Chainlit-Session-Id': sessionId,
- 'X-Chainlit-Thread-Id': idToResume || '',
- 'user-env': JSON.stringify(userEnv),
- 'X-Chainlit-Chat-Profile': chatProfile
- ? encodeURIComponent(chatProfile)
- : ''
- }
+ withCredentials: true,
+ transports,
+ auth: {
+ token: accessToken,
+ clientType: client.type,
+ sessionId,
+ threadId: idToResume || '',
+ userEnv: JSON.stringify(userEnv),
+ chatProfile: chatProfile ? encodeURIComponent(chatProfile) : ''
+ }
+
});
setSession((old) => {
old?.socket?.removeAllListeners();