Skip to content

Commit

Permalink
Test and resolve security vulnerability with get_file and upload_file (
Browse files Browse the repository at this point in the history
…Chainlit#1441)

* Unit tests for `get_file` and `upload_file` endpoints, including authorization.
* Add auth to /project/file get endpoint by @qvalentin , closes Chainlit#1101.

---------

Co-authored-by: qvalentin <[email protected]>
  • Loading branch information
dokterbob and qvalentin authored Oct 17, 2024
1 parent beb44ca commit e65f191
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 23 deletions.
24 changes: 17 additions & 7 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
APIRouter,
Depends,
FastAPI,
File,
Form,
HTTPException,
Query,
Expand Down Expand Up @@ -839,11 +840,9 @@ async def delete_thread(

@router.post("/project/file")
async def upload_file(
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
session_id: str,
file: UploadFile,
current_user: Annotated[
Union[None, User, PersistedUser], Depends(get_current_user)
],
):
"""Upload a file to the session files directory."""

Expand All @@ -868,17 +867,21 @@ async def upload_file(

content = await file.read()

assert file.filename, "No filename for uploaded file"
assert file.content_type, "No content type for uploaded file"

file_response = await session.persist_file(
name=file.filename, content=content, mime=file.content_type
)

return JSONResponse(file_response)
return JSONResponse(content=file_response)


@router.get("/project/file/{file_id}")
async def get_file(
file_id: str,
session_id: Optional[str] = None,
session_id: str,
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
):
"""Get a file from the session files directory."""

Expand All @@ -888,10 +891,17 @@ async def get_file(

if not session:
raise HTTPException(
status_code=404,
detail="Session not found",
status_code=401,
detail="Unauthorized",
)

if current_user:
if not session.user or session.user.identifier != current_user.identifier:
raise HTTPException(
status_code=401,
detail="You are not authorized to download files from this session",
)

if file_id in session.files:
file = session.files[file_id]
return FileResponse(file["path"], media_type=file["type"])
Expand Down
39 changes: 25 additions & 14 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from contextlib import asynccontextmanager
from typing import Callable
from unittest.mock import AsyncMock, Mock

import pytest
Expand All @@ -20,20 +21,30 @@ def persisted_test_user():


@pytest.fixture
def mock_session():
mock = Mock(spec=WebsocketSession)
mock.id = "test_session_id"
mock.user_env = {"test_env": "value"}
mock.chat_settings = {}
mock.chat_profile = None
mock.http_referer = None
mock.client_type = "webapp"
mock.languages = ["en"]
mock.thread_id = "test_thread_id"
mock.emit = AsyncMock()
mock.has_first_interaction = True

return mock
def mock_session_factory(persisted_test_user: PersistedUser) -> Callable[..., Mock]:
def create_mock_session(**kwargs) -> Mock:
mock = Mock(spec=WebsocketSession)
mock.user = kwargs.get("user", persisted_test_user)
mock.id = kwargs.get("id", "test_session_id")
mock.user_env = kwargs.get("user_env", {"test_env": "value"})
mock.chat_settings = kwargs.get("chat_settings", {})
mock.chat_profile = kwargs.get("chat_profile", None)
mock.http_referer = kwargs.get("http_referer", None)
mock.client_type = kwargs.get("client_type", "webapp")
mock.languages = kwargs.get("languages", ["en"])
mock.thread_id = kwargs.get("thread_id", "test_thread_id")
mock.emit = AsyncMock()
mock.has_first_interaction = kwargs.get("has_first_interaction", True)
mock.files = kwargs.get("files", {})

return mock

return create_mock_session


@pytest.fixture
def mock_session(mock_session_factory) -> Mock:
return mock_session_factory()


@asynccontextmanager
Expand Down
Loading

0 comments on commit e65f191

Please sign in to comment.