Skip to content

Commit

Permalink
Check whether path is in appropriate folder.
Browse files Browse the repository at this point in the history
  • Loading branch information
dokterbob committed Sep 9, 2024
1 parent 6a5644a commit 51baf5b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
8 changes: 8 additions & 0 deletions backend/chainlit/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Util functions which are explicitly not part of the public API."""

from pathlib import Path


def is_path_inside(child_path: Path, parent_path: Path) -> bool:
"""Check if the child path is inside the parent path."""
return parent_path.resolve() in child_path.resolve().parents
15 changes: 12 additions & 3 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from pydantic.dataclasses import Field, dataclass
from starlette.datastructures import Headers

from ._utils import is_path_inside

if TYPE_CHECKING:
from chainlit.action import Action
from chainlit.element import ElementBased
Expand Down Expand Up @@ -353,18 +355,25 @@ def load_translation(self, language: str):
config_translation_dir, f"{default_language}.json"
)

if os.path.exists(translation_lib_file_path):
if is_path_inside(
Path(translation_lib_file_path), Path(config_translation_dir)
) and os.path.exists(translation_lib_file_path):
with open(translation_lib_file_path, "r", encoding="utf-8") as f:
translation = json.load(f)
elif os.path.exists(translation_lib_parent_language_file_path):
elif is_path_inside(
Path(translation_lib_parent_language_file_path),
Path(config_translation_dir),
) and os.path.exists(translation_lib_parent_language_file_path):
logger.warning(
f"Translation file for {language} not found. Using parent translation {parent_language}."
)
with open(
translation_lib_parent_language_file_path, "r", encoding="utf-8"
) as f:
translation = json.load(f)
elif os.path.exists(default_translation_lib_file_path):
elif is_path_inside(
Path(default_translation_lib_file_path), Path(config_translation_dir)
) and os.path.exists(default_translation_lib_file_path):
logger.warning(
f"Translation file for {language} not found. Using default translation {default_language}."
)
Expand Down
6 changes: 5 additions & 1 deletion backend/chainlit/markdown.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from pathlib import Path

from chainlit.logger import logger
from ._utils import is_path_inside

# Default chainlit.md file created if none exists
DEFAULT_MARKDOWN_STR = """# Welcome to Chainlit! 🚀🤖
Expand Down Expand Up @@ -35,7 +37,9 @@ def get_markdown_str(root: str, language: str):
translated_chainlit_md_path = os.path.join(root, f"chainlit_{language}.md")
default_chainlit_md_path = os.path.join(root, "chainlit.md")

if os.path.exists(translated_chainlit_md_path):
if is_path_inside(Path(translated_chainlit_md_path), Path(root)) and os.path.exists(
translated_chainlit_md_path
):
chainlit_md_path = translated_chainlit_md_path
else:
chainlit_md_path = default_chainlit_md_path
Expand Down
17 changes: 13 additions & 4 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
from typing_extensions import Annotated
from watchfiles import awatch

from ._utils import is_path_inside

mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

Expand Down Expand Up @@ -666,6 +668,7 @@ async def project_settings(
"""Return project settings. This is called by the UI before the establishing the websocket connection."""

# Load the markdown file based on the provided language

markdown = get_markdown_str(config.root, language)

profiles = []
Expand Down Expand Up @@ -897,8 +900,7 @@ async def serve_file(
base_path = Path(config.project.local_fs_path).resolve()
file_path = (base_path / filename).resolve()

# Check if the base path is a parent of the file path
if base_path not in file_path.parents:
if not is_path_inside(file_path, base_path):
raise HTTPException(status_code=400, detail="Invalid filename")

if file_path.is_file():
Expand Down Expand Up @@ -941,6 +943,7 @@ async def get_logo(theme: Optional[Theme] = Query(Theme.light)):

if not logo_path:
raise HTTPException(status_code=404, detail="Missing default logo")

media_type, _ = mimetypes.guess_type(logo_path)

return FileResponse(logo_path, media_type=media_type)
Expand All @@ -954,13 +957,19 @@ async def get_avatar(avatar_id: str):

avatar_id = avatar_id.strip().lower().replace(" ", "_")

avatar_path = os.path.join(APP_ROOT, "public", "avatars", f"{avatar_id}.*")
base_path = os.path.join(APP_ROOT, "public", "avatars")
avatar_glob = os.path.join(base_path, f"{avatar_id}.*")

files = glob.glob(avatar_path)
files = glob.glob(avatar_glob)

if files:
avatar_path = files[0]

if not is_path_inside(Path(avatar_path), Path(base_path)):
raise HTTPException(status_code=400, detail="Invalid filename")

media_type, _ = mimetypes.guess_type(avatar_path)

return FileResponse(avatar_path, media_type=media_type)
else:
return await get_favicon()
Expand Down

0 comments on commit 51baf5b

Please sign in to comment.