Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: datetime and collection id validation fixes and improved collection extent method #81

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
13 changes: 13 additions & 0 deletions .github/workflows/cicd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,19 @@ jobs:
env:
AWS_DEFAULT_REGION: us-west-2

services:
pgstac:
image: ghcr.io/stac-utils/pgstac:v0.7.10
env:
POSTGRES_USER: username
POSTGRES_PASSWORD: password
POSTGRES_DB: postgis
PGUSER: username
PGPASSWORD: password
PGDATABASE: postgis
ports:
- 5432:5432

steps:
- uses: actions/checkout@v3

Expand Down
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ This codebase utilizes the [Pydantic SSM Settings](https://github.com/developmen
2. Install dependencies:

```
pip install -r api/requirements.txt
pip install -r requirements.txt -r api/requirements.txt
```

3. Run API:
Expand Down Expand Up @@ -66,6 +66,20 @@ This script is also available at `scripts/sync_env.sh`, which can be invoked wit
. scripts/sync_env.sh stac-ingestor-env-secret-<stage>
```

## Testing

```shell
pytest
```

Some tests require a locally-running **pgstac** database, and will be skipped if there isn't one at `postgresql://username:password@localhost:5432/postgis`.
To run the **pgstac** tests:

```shell
docker compose up -d
pytest
docker compose down
```

## License

Expand Down
1 change: 1 addition & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ psycopg[binary,pool]>=3.0.15
pydantic_ssm_settings>=0.2.0
pydantic>=1.9.0,<2
pypgstac==0.7.10
pystac[jsonschema]>=1.8.4
python-multipart==0.0.5
requests>=2.27.1
s3fs==2023.3.0
Expand Down
22 changes: 16 additions & 6 deletions api/src/collection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from typing import Union
from typing import Optional, Union

import fsspec
import xarray as xr
import xstac
from pypgstac.db import PgstacDB

from src.schemas import (
COGDataset,
DashboardCollection,
Expand All @@ -13,6 +14,7 @@
ZarrDataset,
)
from src.utils import (
DbCreds,
IngestionType,
convert_decimals_to_float,
get_db_credentials,
Expand Down Expand Up @@ -40,8 +42,10 @@ class Publisher:
"type": "Collection",
"stac_version": "1.0.0",
}
db_creds: Optional[DbCreds]

def __init__(self) -> None:
def __init__(self, db_creds: Optional[DbCreds] = None) -> None:
self.db_creds = db_creds
self.func_map = {
DataType.zarr: self.create_zarr_collection,
DataType.cog: self.create_cog_collection,
Expand Down Expand Up @@ -147,9 +151,9 @@ def ingest(self, collection: DashboardCollection):
does necessary preprocessing,
and loads into the PgSTAC collection table
"""
creds = get_db_credentials(os.environ["DB_SECRET_ARN"])
db_creds = self._get_db_credentials()
collection = [convert_decimals_to_float(collection.dict(by_alias=True))]
with PgstacDB(dsn=creds.dsn_string, debug=True) as db:
with PgstacDB(dsn=db_creds.dsn_string, debug=True) as db:
load_into_pgstac(
db=db, ingestions=collection, table=IngestionType.collections
)
Expand All @@ -158,7 +162,13 @@ def delete(self, collection_id: str):
"""
Deletes the collection from the database
"""
creds = get_db_credentials(os.environ["DB_SECRET_ARN"])
with PgstacDB(dsn=creds.dsn_string, debug=True) as db:
db_creds = self._get_db_credentials()
with PgstacDB(dsn=db_creds.dsn_string, debug=True) as db:
loader = VEDALoader(db=db)
loader.delete_collection(collection_id)

def _get_db_credentials(self) -> DbCreds:
    """Return the database credentials to use for pgstac connections.

    Prefers credentials injected at construction time (useful for tests
    against a local pgstac); otherwise fetches them from the AWS secret
    named by the ``DB_SECRET_ARN`` environment variable.
    """
    # Fall back to the AWS secret only when no credentials were supplied.
    return self.db_creds or get_db_credentials(os.environ["DB_SECRET_ARN"])
5 changes: 5 additions & 0 deletions api/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class Settings(BaseSettings):
client_id: str = Field(description="The Cognito APP client ID")
client_secret: str = Field(description="The Cognito APP client secret")

path_prefix: Optional[str] = Field(
"",
description="Optional path prefix to add to all api endpoints",
)

class Config(AwsSsmSourceConfig):
env_file = ".env"

Expand Down
33 changes: 18 additions & 15 deletions api/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import src.helpers as helpers
import src.schemas as schemas
import src.services as services
from fastapi import Body, Depends, FastAPI, HTTPException
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
Expand Down Expand Up @@ -44,12 +44,12 @@
},
contact={"url": "https://github.com/NASA-IMPACT/veda-stac-ingestor"},
)
app.router.route_class = LoggerRouteHandler
api_router = APIRouter(prefix=settings.path_prefix, route_class=LoggerRouteHandler)

publisher = collection_loader.Publisher()


@app.get(
@api_router.get(
"/ingestions", response_model=schemas.ListIngestionResponse, tags=["Ingestion"]
)
async def list_ingestions(
Expand All @@ -64,7 +64,7 @@ async def list_ingestions(
)


@app.post(
@api_router.post(
"/ingestions",
response_model=schemas.Ingestion,
tags=["Ingestion"],
Expand All @@ -86,7 +86,7 @@ async def create_ingestion(
).enqueue(db)


@app.get(
@api_router.get(
"/ingestions/{ingestion_id}",
response_model=schemas.Ingestion,
tags=["Ingestion"],
Expand All @@ -100,7 +100,7 @@ def get_ingestion(
return ingestion


@app.patch(
@api_router.patch(
"/ingestions/{ingestion_id}",
response_model=schemas.Ingestion,
tags=["Ingestion"],
Expand All @@ -117,7 +117,7 @@ def update_ingestion(
return updated_item.save(db)


@app.delete(
@api_router.delete(
"/ingestions/{ingestion_id}",
response_model=schemas.Ingestion,
tags=["Ingestion"],
Expand All @@ -139,7 +139,7 @@ def cancel_ingestion(
return ingestion.cancel(db)


@app.post(
@api_router.post(
"/collections",
tags=["Collection"],
status_code=201,
Expand All @@ -160,7 +160,7 @@ def publish_collection(collection: schemas.DashboardCollection):
)


@app.delete(
@api_router.delete(
"/collections/{collection_id}",
tags=["Collection"],
dependencies=[Depends(auth.get_username)],
Expand All @@ -177,7 +177,7 @@ def delete_collection(collection_id: str):
raise HTTPException(status_code=400, detail=(f"{e}"))


@app.post(
@api_router.post(
"/workflow-executions",
response_model=schemas.WorkflowExecutionResponse,
tags=["Workflow-Executions"],
Expand All @@ -194,7 +194,7 @@ async def start_workflow_execution(
return helpers.trigger_discover(input)


@app.get(
@api_router.get(
"/workflow-executions/{workflow_execution_id}",
response_model=Union[schemas.ExecutionResponse, schemas.BaseResponse],
tags=["Workflow-Executions"],
Expand All @@ -208,7 +208,7 @@ async def get_workflow_execution_status(
return helpers.get_status(workflow_execution_id)


@app.post("/token", tags=["Auth"], response_model=schemas.AuthResponse)
@api_router.post("/token", tags=["Auth"], response_model=schemas.AuthResponse)
async def get_token(
form_data: OAuth2PasswordRequestForm = Depends(),
) -> Dict:
Expand All @@ -224,7 +224,7 @@ async def get_token(
)


@app.post(
@api_router.post(
"/dataset/validate",
tags=["Dataset"],
dependencies=[Depends(auth.get_username)],
Expand All @@ -250,7 +250,7 @@ def validate_dataset(dataset: schemas.COGDataset):
}


@app.post(
@api_router.post(
"/dataset/publish", tags=["Dataset"], dependencies=[Depends(auth.get_username)]
)
async def publish_dataset(
Expand Down Expand Up @@ -280,14 +280,17 @@ async def publish_dataset(
return return_dict


@app.get("/auth/me", tags=["Auth"], response_model=schemas.WhoAmIResponse)
@api_router.get("/auth/me", tags=["Auth"], response_model=schemas.WhoAmIResponse)
def who_am_i(claims=Depends(auth.decode_token)):
    """Echo back the claims decoded from the caller's JWT."""
    return claims


app.include_router(api_router)


# exception handling
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
Expand Down
3 changes: 2 additions & 1 deletion api/src/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aws_lambda_powertools.metrics import MetricUnit # noqa: F401
from fastapi import Request, Response
from fastapi.routing import APIRoute

from src.config import Settings

settings = Settings()
Expand Down Expand Up @@ -40,7 +41,7 @@ async def route_handler(request: Request) -> Response:
value=1,
)
tracer.put_annotation(key="path", value=request.url.path)
tracer.capture_method(original_route_handler)(request)
await tracer.capture_method(original_route_handler)(request)
return await original_route_handler(request)

return route_handler
2 changes: 1 addition & 1 deletion api/src/schema_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class TemporalExtent(BaseModel):
startdate: datetime
enddate: datetime

@root_validator
@root_validator(pre=True)
def check_dates(cls, v):
if v["startdate"] >= v["enddate"]:
raise ValueError("Invalid extent - startdate must be before enddate")
Expand Down
15 changes: 11 additions & 4 deletions api/src/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,9 @@ class CmrInput(WorkflowInputBase):


# allows the construction of models with a list of discriminated unions
ItemUnion = Annotated[Union[S3Input, CmrInput], Field(discriminator="discovery")]
ItemUnion = Annotated[
Union[S3Input, CmrInput], Field(discriminator="discovery") # noqa
]


class Dataset(BaseModel):
Expand All @@ -300,9 +302,14 @@ class Dataset(BaseModel):
@validator("collection")
def check_id(cls, collection):
if not re.match(r"[a-z]+(?:-[a-z]+)*", collection):
raise ValueError(
"Invalid id - id must be all lowercase, with optional '-' delimiters"
)
# allow collection id to "break the rules" if an already-existing collection matches
try:
validators.collection_exists(collection_id=collection)
except ValueError:
# overwrite error - the issue isn't the non-existing function, it's the new id
raise ValueError(
"Invalid id - id must be all lowercase, with optional '-' delimiters"
)
return collection

@root_validator
Expand Down
13 changes: 1 addition & 12 deletions api/src/vedaloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,7 @@ def update_collection_summaries(self, collection_id: str) -> None:
logger.info(f"Updating extents for collection: {collection_id}.")
cur.execute(
"""
UPDATE collections SET
content = content ||
jsonb_build_object(
'extent', jsonb_build_object(
'spatial', jsonb_build_object(
'bbox', collection_bbox(collections.id)
),
'temporal', jsonb_build_object(
'interval', collection_temporal_extent(collections.id)
)
)
)
UPDATE collections set content = content || pgstac.collection_extent(collections.id)
WHERE collections.id=%s;
""",
(collection_id,),
Expand Down
Loading
Loading