diff --git a/CHANGES.md b/CHANGES.md index d72b9749d..4fa952b35 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,7 +4,9 @@ ### Added +* Add ability to override ItemCollectionUri and SearchGetRequest models ([#271](https://github.com/stac-utils/stac-fastapi/pull/271)) * Added `collections` attribute to list of default fields to include, so that we satisfy the STAC API spec, which requires a `collections` attribute to be output when an item is part of a collection + ### Removed ### Changed diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index d2966314a..cabc0f451 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -73,6 +73,8 @@ class StacApi: stac_version: str = attr.ib(default=STAC_VERSION) description: str = attr.ib(default="stac-fastapi") search_request_model: Type[Search] = attr.ib(default=STACSearch) + search_get_request: Type[SearchGetRequest] = attr.ib(default=SearchGetRequest) + item_collection_uri: Type[ItemCollectionUri] = attr.ib(default=ItemCollectionUri) response_class: Type[Response] = attr.ib(default=JSONResponse) middlewares: List = attr.ib(default=attr.Factory(lambda: [BrotliMiddleware])) @@ -199,7 +201,9 @@ def register_get_search(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=self._create_endpoint(self.client.get_search, SearchGetRequest), + endpoint=self._create_endpoint( + self.client.get_search, self.search_get_request + ), ) def register_get_collections(self): @@ -255,7 +259,7 @@ def register_get_item_collection(self): response_model_exclude_none=True, methods=["GET"], endpoint=self._create_endpoint( - self.client.item_collection, ItemCollectionUri + self.client.item_collection, self.item_collection_uri ), ) diff --git a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py index 68719c035..b1d001025 100644 --- a/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/stac_fastapi/pgstac/core.py @@ -1,7 +1,7 @@ """Item crud client.""" import re from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union from urllib.parse import urljoin import attr @@ -27,6 +27,8 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" + search_request_model: Type[PgstacSearch] = attr.ib(init=False, default=PgstacSearch) + async def all_collections(self, **kwargs) -> Collections: """Read all collections from the database.""" request: Request = kwargs["request"] @@ -168,7 +170,7 @@ async def _search_base( return collection async def item_collection( - self, id: str, limit: int = 10, token: str = None, **kwargs + self, id: str, limit: Optional[int] = None, token: str = None, **kwargs ) -> ItemCollection: """Get all items from a specific collection. @@ -185,7 +187,7 @@ async def item_collection( # If collection does not exist, NotFoundError wil be raised await self.get_collection(id, **kwargs) - req = PgstacSearch(collections=[id], limit=limit, token=token) + req = self.search_request_model(collections=[id], limit=limit, token=token) item_collection = await self._search_base(req, **kwargs) links = await CollectionLinks( collection_id=id, request=kwargs["request"] @@ -207,7 +209,9 @@ async def get_item(self, item_id: str, collection_id: str, **kwargs) -> Item: # If collection does not exist, NotFoundError wil be raised await self.get_collection(collection_id, **kwargs) - req = PgstacSearch(ids=[item_id], collections=[collection_id], limit=1) + req = self.search_request_model( + ids=[item_id], collections=[collection_id], limit=1 + ) item_collection = await self._search_base(req, **kwargs) if not item_collection["features"]: raise NotFoundError( @@ -238,7 +242,7 @@ async def get_search( ids: Optional[List[str]] = None, bbox: Optional[List[NumType]] = None, datetime: Optional[Union[str, datetime]] = None, - limit: Optional[int] = 10, + limit: Optional[int] = None, query: Optional[str] = None, token: Optional[str] = None, fields: Optional[List[str]] = None, @@ -292,7 +296,7 @@ async def get_search( # Do the request try: - search_request = PgstacSearch(**base_args) + search_request = self.search_request_model(**base_args) except ValidationError: raise HTTPException(status_code=400, detail="Invalid parameters provided") return await self.post_search(search_request, request=kwargs["request"])