diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml
index 124a466..902f158 100644
--- a/.github/workflows/cicd.yml
+++ b/.github/workflows/cicd.yml
@@ -16,6 +16,19 @@ jobs:
     env:
       AWS_DEFAULT_REGION: us-west-2
 
+    services:
+      pgstac:
+        image: ghcr.io/stac-utils/pgstac:v0.7.10
+        env:
+          POSTGRES_USER: username
+          POSTGRES_PASSWORD: password
+          POSTGRES_DB: postgis
+          PGUSER: username
+          PGPASSWORD: password
+          PGDATABASE: postgis
+        ports:
+          - 5432:5432
+
     steps:
       - uses: actions/checkout@v3
diff --git a/README.md b/README.md
index df51367..edc1373 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ This codebase utilizes the [Pydantic SSM Settings](https://github.com/developmen
 2. Install dependencies:
 
    ```
-   pip install -r api/requirements.txt
+   pip install -r requirements.txt -r api/requirements.txt
    ```
 
 3. Run API:
@@ -66,6 +66,20 @@ This script is also available at `scripts/sync_env.sh`, which can be invoked wit
 
    . scripts/sync_env.sh stac-ingestor-env-secret-
 
+## Testing
+
+```shell
+pytest
+```
+
+Some tests require a locally running **pgstac** database and are skipped if one is not available at `postgresql://username:password@localhost:5432/postgis`.
+To run the **pgstac** tests:
+
+```shell
+docker compose up -d
+pytest
+docker compose down
+```
 
 ## License
diff --git a/api/requirements.txt b/api/requirements.txt
index c6828dc..832cbbd 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -10,6 +10,7 @@ psycopg[binary,pool]>=3.0.15
 pydantic_ssm_settings>=0.2.0
 pydantic>=1.9.0,<2
 pypgstac==0.7.10
+pystac[jsonschema]>=1.8.4
 python-multipart==0.0.5
 requests>=2.27.1
 s3fs==2023.3.0
diff --git a/api/src/collection.py b/api/src/collection.py
index 1e24e67..611a137 100644
--- a/api/src/collection.py
+++ b/api/src/collection.py
@@ -1,10 +1,11 @@
 import os
-from typing import Union
+from typing import Optional, Union
 
 import fsspec
 import xarray as xr
 import xstac
 from pypgstac.db import PgstacDB
+
 from src.schemas import (
     COGDataset,
     DashboardCollection,
@@ -13,6 +14,7 @@
     ZarrDataset,
 )
 from src.utils import (
+    DbCreds,
     IngestionType,
     convert_decimals_to_float,
     get_db_credentials,
@@ -40,8 +42,10 @@ class Publisher:
         "type": "Collection",
         "stac_version": "1.0.0",
     }
+    db_creds: Optional[DbCreds]
 
-    def __init__(self) -> None:
+    def __init__(self, db_creds: Optional[DbCreds] = None) -> None:
+        self.db_creds = db_creds
         self.func_map = {
             DataType.zarr: self.create_zarr_collection,
             DataType.cog: self.create_cog_collection,
@@ -147,9 +151,9 @@ def ingest(self, collection: DashboardCollection):
         does necessary preprocessing,
         and loads into the PgSTAC collection table
         """
-        creds = get_db_credentials(os.environ["DB_SECRET_ARN"])
+        db_creds = self._get_db_credentials()
         collection = [convert_decimals_to_float(collection.dict(by_alias=True))]
-        with PgstacDB(dsn=creds.dsn_string, debug=True) as db:
+        with PgstacDB(dsn=db_creds.dsn_string, debug=True) as db:
             load_into_pgstac(
                 db=db, ingestions=collection, table=IngestionType.collections
             )
@@ -158,7 +162,13 @@ def delete(self, collection_id: str):
         """
         Deletes the collection from the database
         """
-        creds = get_db_credentials(os.environ["DB_SECRET_ARN"])
-        with PgstacDB(dsn=creds.dsn_string, debug=True) as db:
+        db_creds = self._get_db_credentials()
+        with PgstacDB(dsn=db_creds.dsn_string, debug=True) as db:
             loader = VEDALoader(db=db)
             loader.delete_collection(collection_id)
+
+    def _get_db_credentials(self) -> DbCreds:
+        if self.db_creds:
+            return self.db_creds
+        else:
+            return get_db_credentials(os.environ["DB_SECRET_ARN"])
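The `Publisher` change above makes database credentials injectable instead of always being fetched from Secrets Manager. A minimal sketch of how the two paths are meant to be exercised, assuming `DbCreds` is the pydantic model from `src.utils` whose `dsn_string` assembles a `postgresql://` DSN (the field names mirror `api/tests/test_collection.py` further down in this diff):

```python
from src.collection import Publisher
from src.utils import DbCreds

# Tests inject explicit credentials pointing at the local docker-compose pgstac.
test_publisher = Publisher(
    DbCreds(
        username="username",
        password="password",
        host="localhost",
        port=5432,
        dbname="postgis",
        engine="postgresql",
    )
)

# Production keeps the old behavior: with no argument, _get_db_credentials()
# falls back to get_db_credentials(os.environ["DB_SECRET_ARN"]).
prod_publisher = Publisher()
```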
diff --git a/api/src/config.py b/api/src/config.py
index 7ac857f..a9913b6 100644
--- a/api/src/config.py
+++ b/api/src/config.py
@@ -32,6 +32,11 @@ class Settings(BaseSettings):
     client_id: str = Field(description="The Cognito APP client ID")
     client_secret: str = Field(description="The Cognito APP client secret")
 
+    path_prefix: Optional[str] = Field(
+        "",
+        description="Optional path prefix to add to all api endpoints",
+    )
+
     class Config(AwsSsmSourceConfig):
         env_file = ".env"
diff --git a/api/src/main.py b/api/src/main.py
index da118fd..9a40973 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -11,7 +11,7 @@
 import src.helpers as helpers
 import src.schemas as schemas
 import src.services as services
-from fastapi import Body, Depends, FastAPI, HTTPException
+from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse
 from fastapi.security import OAuth2PasswordRequestForm
@@ -44,12 +44,12 @@
     },
     contact={"url": "https://github.com/NASA-IMPACT/veda-stac-ingestor"},
 )
-app.router.route_class = LoggerRouteHandler
+api_router = APIRouter(prefix=settings.path_prefix, route_class=LoggerRouteHandler)
 
 publisher = collection_loader.Publisher()
 
 
-@app.get(
+@api_router.get(
     "/ingestions", response_model=schemas.ListIngestionResponse, tags=["Ingestion"]
 )
 async def list_ingestions(
@@ -64,7 +64,7 @@
     )
 
 
-@app.post(
+@api_router.post(
     "/ingestions",
     response_model=schemas.Ingestion,
     tags=["Ingestion"],
@@ -86,7 +86,7 @@ async def create_ingestion(
     ).enqueue(db)
 
 
-@app.get(
+@api_router.get(
     "/ingestions/{ingestion_id}",
     response_model=schemas.Ingestion,
     tags=["Ingestion"],
@@ -100,7 +100,7 @@ def get_ingestion(
     return ingestion
 
 
-@app.patch(
+@api_router.patch(
     "/ingestions/{ingestion_id}",
     response_model=schemas.Ingestion,
     tags=["Ingestion"],
@@ -117,7 +117,7 @@ def update_ingestion(
     return updated_item.save(db)
 
 
-@app.delete(
+@api_router.delete(
     "/ingestions/{ingestion_id}",
     response_model=schemas.Ingestion,
     tags=["Ingestion"],
@@ -139,7 +139,7 @@ def cancel_ingestion(
     return ingestion.cancel(db)
 
 
-@app.post(
+@api_router.post(
     "/collections",
     tags=["Collection"],
     status_code=201,
@@ -160,7 +160,7 @@ def publish_collection(collection: schemas.DashboardCollection):
     )
 
 
-@app.delete(
+@api_router.delete(
     "/collections/{collection_id}",
     tags=["Collection"],
     dependencies=[Depends(auth.get_username)],
@@ -177,7 +177,7 @@ def delete_collection(collection_id: str):
         raise HTTPException(status_code=400, detail=(f"{e}"))
 
 
-@app.post(
+@api_router.post(
     "/workflow-executions",
     response_model=schemas.WorkflowExecutionResponse,
     tags=["Workflow-Executions"],
@@ -194,7 +194,7 @@ async def start_workflow_execution(
     return helpers.trigger_discover(input)
 
 
-@app.get(
+@api_router.get(
     "/workflow-executions/{workflow_execution_id}",
     response_model=Union[schemas.ExecutionResponse, schemas.BaseResponse],
     tags=["Workflow-Executions"],
@@ -208,7 +208,7 @@ async def get_workflow_execution_status(
     return helpers.get_status(workflow_execution_id)
 
 
-@app.post("/token", tags=["Auth"], response_model=schemas.AuthResponse)
+@api_router.post("/token", tags=["Auth"], response_model=schemas.AuthResponse)
 async def get_token(
     form_data: OAuth2PasswordRequestForm = Depends(),
 ) -> Dict:
@@ -224,7 +224,7 @@
     )
 
 
-@app.post(
+@api_router.post(
     "/dataset/validate",
     tags=["Dataset"],
     dependencies=[Depends(auth.get_username)],
@@ -250,7 +250,7 @@ def validate_dataset(dataset: schemas.COGDataset):
     }
 
 
-@app.post(
+@api_router.post(
     "/dataset/publish", tags=["Dataset"], dependencies=[Depends(auth.get_username)]
 )
 async def publish_dataset(
@@ -280,7 +280,7 @@ async def publish_dataset(
     return return_dict
 
 
-@app.get("/auth/me", tags=["Auth"], response_model=schemas.WhoAmIResponse)
+@api_router.get("/auth/me", tags=["Auth"], response_model=schemas.WhoAmIResponse)
 def who_am_i(claims=Depends(auth.decode_token)):
     """
     Return claims for the provided JWT
@@ -288,6 +288,9 @@
     return claims
 
 
+app.include_router(api_router)
+
+
 # exception handling
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request, exc):
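The `main.py` rewrite swaps per-app route registration for an `APIRouter` mounted under a configurable prefix. A minimal standalone sketch of the pattern (illustrative paths only; the real prefix comes from `Settings.path_prefix`, e.g. `PATH_PREFIX=/api/ingest`):

```python
from fastapi import APIRouter, FastAPI

app = FastAPI()
api_router = APIRouter(prefix="/api/ingest")  # settings.path_prefix in the real app


@api_router.get("/ingestions")
async def list_ingestions():
    # Stand-in body; the real handler pages through stored ingestions.
    return {"items": [], "next": None}


# Until include_router() is called, none of the router's routes are registered,
# which is why the diff adds app.include_router(api_router) near the end of main.py.
app.include_router(api_router)
# The endpoint is now served at /api/ingest/ingestions.
```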
diff --git a/api/src/monitoring.py b/api/src/monitoring.py
index 1b16980..e40ab52 100644
--- a/api/src/monitoring.py
+++ b/api/src/monitoring.py
@@ -5,6 +5,7 @@
 from aws_lambda_powertools.metrics import MetricUnit  # noqa: F401
 from fastapi import Request, Response
 from fastapi.routing import APIRoute
+
 from src.config import Settings
 
 settings = Settings()
@@ -40,7 +41,7 @@ async def route_handler(request: Request) -> Response:
                 value=1,
             )
             tracer.put_annotation(key="path", value=request.url.path)
-            tracer.capture_method(original_route_handler)(request)
+            await tracer.capture_method(original_route_handler)(request)
             return await original_route_handler(request)
 
         return route_handler
diff --git a/api/src/schema_helpers.py b/api/src/schema_helpers.py
index d66a10a..0621220 100644
--- a/api/src/schema_helpers.py
+++ b/api/src/schema_helpers.py
@@ -49,7 +49,7 @@ class TemporalExtent(BaseModel):
     startdate: datetime
     enddate: datetime
 
-    @root_validator
+    @root_validator(pre=True)
     def check_dates(cls, v):
         if v["startdate"] >= v["enddate"]:
             raise ValueError("Invalid extent - startdate must be before enddate")
diff --git a/api/src/schemas.py b/api/src/schemas.py
index d2c577b..4de7b11 100644
--- a/api/src/schemas.py
+++ b/api/src/schemas.py
@@ -283,7 +283,9 @@ class CmrInput(WorkflowInputBase):
 
 
 # allows the construction of models with a list of discriminated unions
-ItemUnion = Annotated[Union[S3Input, CmrInput], Field(discriminator="discovery")]
+ItemUnion = Annotated[
+    Union[S3Input, CmrInput], Field(discriminator="discovery")  # noqa
+]
 
 
 class Dataset(BaseModel):
@@ -300,9 +302,14 @@ class Dataset(BaseModel):
     @validator("collection")
     def check_id(cls, collection):
         if not re.match(r"[a-z]+(?:-[a-z]+)*", collection):
-            raise ValueError(
-                "Invalid id - id must be all lowercase, with optional '-' delimiters"
-            )
+            # allow collection id to "break the rules" if an already-existing collection matches
+            try:
+                validators.collection_exists(collection_id=collection)
+            except ValueError:
+                # overwrite error - the issue isn't the non-existent collection, it's the new id
+                raise ValueError(
+                    "Invalid id - id must be all lowercase, with optional '-' delimiters"
+                )
        return collection
 
     @root_validator
diff --git a/api/src/vedaloader.py b/api/src/vedaloader.py
index 594b1af..99e0010 100644
--- a/api/src/vedaloader.py
+++ b/api/src/vedaloader.py
@@ -30,18 +30,7 @@ def update_collection_summaries(self, collection_id: str) -> None:
                 logger.info(f"Updating extents for collection: {collection_id}.")
                 cur.execute(
                     """
-                    UPDATE collections SET
-                    content = content ||
-                    jsonb_build_object(
-                        'extent', jsonb_build_object(
-                            'spatial', jsonb_build_object(
-                                'bbox', collection_bbox(collections.id)
-                            ),
-                            'temporal', jsonb_build_object(
-                                'interval', collection_temporal_extent(collections.id)
-                            )
-                        )
-                    )
+                    UPDATE collections SET content = content || pgstac.collection_extent(collections.id)
                     WHERE collections.id=%s;
                     """,
                     (collection_id,),
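The `@root_validator(pre=True)` change makes `check_dates` run against the raw input, before per-field validation, so it no longer depends on both fields having already parsed successfully (presumably the motivation here). A small self-contained sketch of the behavior under the pinned `pydantic>=1.9.0,<2`, with the model trimmed to the two fields the validator reads:

```python
import datetime

from pydantic import BaseModel, ValidationError, root_validator


class TemporalExtent(BaseModel):  # trimmed mirror of api/src/schema_helpers.py
    startdate: datetime.datetime
    enddate: datetime.datetime

    @root_validator(pre=True)
    def check_dates(cls, v):
        if v["startdate"] >= v["enddate"]:
            raise ValueError("Invalid extent - startdate must be before enddate")
        return v


try:
    TemporalExtent(
        startdate=datetime.datetime(2023, 1, 2),
        enddate=datetime.datetime(2023, 1, 1),
    )
except ValidationError as err:
    print(err)  # surfaces the "startdate must be before enddate" message
```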
diff --git a/api/tests/conftest.py b/api/tests/conftest.py
index 5fc3edc..64e8441 100644
--- a/api/tests/conftest.py
+++ b/api/tests/conftest.py
@@ -1,9 +1,15 @@
+import datetime
 import os
+from typing import Generator
 
 import boto3
+import psycopg
 import pytest
 from fastapi.testclient import TestClient
 from moto import mock_dynamodb, mock_ssm
+from pypgstac.db import PgstacDB
+from pystac import Collection, Extent, SpatialExtent, TemporalExtent
+from src.schemas import DashboardCollection
 from stac_pydantic import Item
 
@@ -25,6 +31,7 @@ def test_environ():
     os.environ["RASTER_URL"] = "https://test-raster.url"
     os.environ["USERPOOL_ID"] = "fake_id"
     os.environ["STAGE"] = "testing"
+    os.environ["PATH_PREFIX"] = "/api/ingest"
 
 
 @pytest.fixture
@@ -144,6 +151,21 @@ def example_stac_item():
     }
 
 
+@pytest.fixture
+def dashboard_collection() -> DashboardCollection:
+    collection = Collection(
+        "test-collection",
+        "A test collection",
+        Extent(
+            SpatialExtent(
+                [[-180, -90, 180, 90]],
+            ),
+            TemporalExtent([[datetime.datetime.utcnow(), None]]),
+        ),
+    )
+    return DashboardCollection.parse_obj(collection.to_dict())
+
+
 @pytest.fixture
 def example_ingestion(example_stac_item):
     from src import schemas
@@ -154,3 +176,14 @@
         status=schemas.Status.queued,
         item=Item.parse_obj(example_stac_item),
     )
+
+
+@pytest.fixture
+def pgstac() -> Generator[PgstacDB, None, None]:
+    dsn = "postgresql://username:password@localhost:5432/postgis"
+    try:
+        psycopg.connect(dsn)
+    except Exception:
+        pytest.skip(f"could not connect to pgstac database: {dsn}")
+    with PgstacDB(dsn, commit_on_exit=False) as db:
+        yield db
diff --git a/api/tests/test_collection.py b/api/tests/test_collection.py
new file mode 100644
index 0000000..ed02e2a
--- /dev/null
+++ b/api/tests/test_collection.py
@@ -0,0 +1,32 @@
+import pytest
+from pypgstac.db import PgstacDB
+from pystac import Collection
+from src.collection import Publisher
+from src.schemas import DashboardCollection
+from src.utils import DbCreds
+
+
+@pytest.fixture
+def publisher() -> Publisher:
+    return Publisher(
+        DbCreds(
+            username="username",
+            password="password",
+            host="localhost",
+            port=5432,
+            dbname="postgis",
+            engine="postgresql",
+        )
+    )
+
+
+def test_ingest(
+    pgstac: PgstacDB, publisher: Publisher, dashboard_collection: DashboardCollection
+) -> None:
+    publisher.ingest(dashboard_collection)
+    collection = Collection.from_dict(
+        pgstac.query_one(
+            r"SELECT * FROM pgstac.get_collection(%s)", [dashboard_collection.id]
+        )
+    )
+    collection.validate()
diff --git a/api/tests/test_registration.py b/api/tests/test_registration.py
index 53751dd..d6b10e6 100644
--- a/api/tests/test_registration.py
+++ b/api/tests/test_registration.py
@@ -11,7 +11,7 @@
 from fastapi.testclient import TestClient
 from src import schemas, services
 
-ingestion_endpoint = "/ingestions"
+ingestion_endpoint = "/api/ingest/ingestions"
 
 
 class TestList:
diff --git a/cdk/config.py b/cdk/config.py
index 35bf0af..ba6a5ca 100644
--- a/cdk/config.py
+++ b/cdk/config.py
@@ -79,6 +79,11 @@ class Deployment(BaseSettings):
         description="ID of AWS ECR repository used for OIDC provider",
     )
 
+    permissions_boundary_policy_name: Optional[str] = Field(
+        None,
+        description="Name of IAM policy to define stack permissions boundary",
+    )
+
     class Config:
         env_prefix = ""
         case_sentive = False
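`permissions_boundary_policy_name` is a standard pydantic `BaseSettings` field, so with `env_prefix = ""` it should be settable through an environment variable of the same name in upper case. A sketch under pydantic v1 semantics (the variable value is illustrative):

```python
import os
from typing import Optional

from pydantic import BaseSettings, Field


class Deployment(BaseSettings):  # trimmed mirror of cdk/config.py
    permissions_boundary_policy_name: Optional[str] = Field(
        None,
        description="Name of IAM policy to define stack permissions boundary",
    )

    class Config:
        env_prefix = ""


os.environ["PERMISSIONS_BOUNDARY_POLICY_NAME"] = "my-boundary-policy"
print(Deployment().permissions_boundary_policy_name)  # -> my-boundary-policy
```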
diff --git a/cdk/permission_boundary.py b/cdk/permission_boundary.py
new file mode 100644
index 0000000..487f381
--- /dev/null
+++ b/cdk/permission_boundary.py
@@ -0,0 +1,56 @@
+from typing import Union
+
+import jsii
+from aws_cdk import IAspect, aws_iam
+from constructs import IConstruct
+from jsii._reference_map import _refs
+from jsii._utils import Singleton
+
+
+@jsii.implements(IAspect)
+class PermissionBoundaryAspect:
+    """
+    This aspect finds all aws_iam.Role objects in a node (i.e. a CDK stack) and sets the permissions boundary to the given ARN.
+    """
+
+    def __init__(self, permission_boundary: Union[aws_iam.ManagedPolicy, str]) -> None:
+        """
+        :param permission_boundary: Either an aws_iam.ManagedPolicy object or a managed policy's ARN string
+        """
+        self.permission_boundary = permission_boundary
+
+    def visit(self, construct_ref: IConstruct) -> None:
+        """
+        construct_ref only contains a string reference to an object. To get the actual object,
+        we need to resolve it using the JSII mapping.
+        :param construct_ref: ObjRef object with a string reference to the actual object.
+        :return: None
+        """
+        if isinstance(construct_ref, jsii._kernel.ObjRef) and hasattr(
+            construct_ref, "ref"
+        ):
+            kernel = Singleton._instances[
+                jsii._kernel.Kernel
+            ]  # The same object is available as: jsii.kernel
+            resolve = _refs.resolve(kernel, construct_ref)
+        else:
+            resolve = construct_ref
+
+        def _walk(obj):
+            if isinstance(obj, aws_iam.Role):
+                cfn_role = obj.node.find_child("Resource")
+                policy_arn = (
+                    self.permission_boundary
+                    if isinstance(self.permission_boundary, str)
+                    else self.permission_boundary.managed_policy_arn
+                )
+                cfn_role.add_property_override("PermissionsBoundary", policy_arn)
+            else:
+                if hasattr(obj, "permissions_node"):
+                    for c in obj.permissions_node.children:
+                        _walk(c)
+                if hasattr(obj, "node") and obj.node.children:
+                    for c in obj.node.children:
+                        _walk(c)
+
+        _walk(resolve)
diff --git a/cdk/stack.py b/cdk/stack.py
index 8259f8f..9c3507c 100644
--- a/cdk/stack.py
+++ b/cdk/stack.py
@@ -3,6 +3,7 @@
 from typing import Dict
 
 from aws_cdk import (
+    Aspects,
     Duration,
     RemovalPolicy,
     Stack,
@@ -30,6 +31,19 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(scope, construct_id, **kwargs)
+
+        if config.permissions_boundary_policy_name:
+            permission_boundary_policy = iam.ManagedPolicy.from_managed_policy_name(
+                self,
+                "permission-boundary",
+                config.permissions_boundary_policy_name,
+            )
+            iam.PermissionsBoundary.of(self).apply(permission_boundary_policy)
+
+            from cdk.permission_boundary import PermissionBoundaryAspect
+
+            Aspects.of(self).add(PermissionBoundaryAspect(permission_boundary_policy))
+
         table = self.build_table()
         jwks_url = self.build_jwks_url(config.userpool_id)
@@ -264,7 +278,7 @@ def build_api_lambda(
             vpc_subnets=ec2.SubnetSelection(
                 subnet_type=ec2.SubnetType.PUBLIC
                 if db_subnet_public
-                else ec2.SubnetType.PRIVATE_ISOLATED
+                else ec2.SubnetType.PRIVATE_WITH_NAT
             ),
             allow_public_subnet=True,
             memory_size=2048,
@@ -330,7 +344,7 @@ def build_ingestor(
             vpc_subnets=ec2.SubnetSelection(
                 subnet_type=ec2.SubnetType.PUBLIC
                 if db_subnet_public
-                else ec2.SubnetType.PRIVATE_ISOLATED
+                else ec2.SubnetType.PRIVATE_WITH_NAT
             ),
             allow_public_subnet=True,
             memory_size=2048,
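The stack appears to apply the boundary through two complementary mechanisms: `iam.PermissionsBoundary.of(...)` handles roles created through CDK's usual paths, while the custom aspect walks the construct tree and patches any remaining `aws_iam.Role` at the CloudFormation level. A condensed sketch of the same wiring outside the stack class (the stack and policy names are illustrative):

```python
from aws_cdk import App, Aspects, Stack, aws_iam as iam

from cdk.permission_boundary import PermissionBoundaryAspect

app = App()
stack = Stack(app, "example-stack")

# Look up the boundary policy by name, then apply it stack-wide...
boundary = iam.ManagedPolicy.from_managed_policy_name(
    stack, "permission-boundary", "my-boundary-policy"
)
iam.PermissionsBoundary.of(stack).apply(boundary)

# ...and sweep the construct tree for any Role the first mechanism missed.
Aspects.of(stack).add(PermissionBoundaryAspect(boundary))
```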
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..e1537c0
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,15 @@
+version: '3'
+services:
+  database:
+    container_name: pgstac
+    image: ghcr.io/stac-utils/pgstac:v0.7.10
+    environment:
+      - POSTGRES_USER=username
+      - POSTGRES_PASSWORD=password
+      - POSTGRES_DB=postgis
+      - PGUSER=username
+      - PGPASSWORD=password
+      - PGDATABASE=postgis
+    ports:
+      - "5432:5432"
+    command: postgres -N 500
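Once the compose service is up, connectivity can be sanity-checked with the same `pypgstac` API the tests use (`PgstacDB` appears in the conftest and test above). A sketch, assuming `PgstacDB.version` reports the installed pgstac migration version:

```python
from pypgstac.db import PgstacDB

# Same DSN the test suite's pgstac fixture probes before running DB-backed tests.
dsn = "postgresql://username:password@localhost:5432/postgis"

with PgstacDB(dsn) as db:
    # For this image it should report 0.7.10, matching the pinned
    # pypgstac==0.7.10 in api/requirements.txt.
    print(db.version)
```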