diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py index 8992f859ac4..249bcae167d 100644 --- a/chromadb/server/fastapi/__init__.py +++ b/chromadb/server/fastapi/__init__.py @@ -69,6 +69,33 @@ def _uuid(uuid_str: str) -> UUID: raise InvalidUUIDError(f"Could not parse {uuid_str} as a UUID") +class ChromaAPIRouter(fastapi.APIRouter): + # A simple subclass of fastapi's APIRouter which treats URLs with a trailing "/" the + # same as URLs without. Docs will only contain URLs without trailing "/"s. + def add_api_route(self, path: str, *args: Any, **kwargs: Any) -> None: + # If kwargs["include_in_schema"] isn't passed OR is True, we should only + # include the non-"/" path. If kwargs["include_in_schema"] is False, include + # neither. + exclude_from_schema = ( + "include_in_schema" in kwargs and not kwargs["include_in_schema"] + ) + + def include_in_schema(path: str) -> bool: + nonlocal exclude_from_schema + return not exclude_from_schema and not path.endswith("/") + + kwargs["include_in_schema"] = include_in_schema(path) + super().add_api_route(path, *args, **kwargs) + + if path.endswith("/"): + path = path[:-1] + else: + path = path + "/" + + kwargs["include_in_schema"] = include_in_schema(path) + super().add_api_route(path, *args, **kwargs) + + class FastAPI(chromadb.server.Server): def __init__(self, settings: Settings): super().__init__(settings) @@ -84,7 +111,7 @@ def __init__(self, settings: Settings): allow_methods=["*"], ) - self.router = fastapi.APIRouter() + self.router = ChromaAPIRouter() self.router.add_api_route("/api/v1", self.root, methods=["GET"]) self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])