diff --git a/infrastructure/main.tf b/infrastructure/main.tf index 7892e39a..57b70d86 100644 --- a/infrastructure/main.tf +++ b/infrastructure/main.tf @@ -161,6 +161,7 @@ resource "aws_lambda_function" "workflows_api_handler" { INGEST_URL = var.ingest_url RASTER_URL = var.raster_url STAC_URL = var.stac_url + MWAA_ENV = module.mwaa.airflow_url } } } diff --git a/workflows_api/runtime/src/auth.py b/workflows_api/runtime/src/auth.py index e495630f..68be7ae3 100644 --- a/workflows_api/runtime/src/auth.py +++ b/workflows_api/runtime/src/auth.py @@ -56,7 +56,7 @@ def decode_token( claims.setdefault("aud", claims["client_id"]) claims.validate() - return claims + return claims, token except errors.JoseError: # logger.exception("Unable to decode token") raise HTTPException(status_code=403, detail="Bad auth token") @@ -65,6 +65,9 @@ def decode_token( def get_username(claims: security.HTTPBasicCredentials = Depends(decode_token)): return claims["sub"] +def get_and_validate_token(token: security.HTTPAuthorizationCredentials = Depends(token_scheme)): + decode_token(token) + return token def _get_secret_hash(username: str, client_id: str, client_secret: str) -> str: # A keyed-hash message authentication code (HMAC) calculated using diff --git a/workflows_api/runtime/src/collection_publisher.py b/workflows_api/runtime/src/collection_publisher.py index d5aa43df..20be420c 100644 --- a/workflows_api/runtime/src/collection_publisher.py +++ b/workflows_api/runtime/src/collection_publisher.py @@ -1,6 +1,7 @@ import os from typing import Optional, Union +import requests import fsspec import xarray as xr import xstac @@ -12,34 +13,27 @@ ZarrDataset, ) from src.validators import get_s3_credentials -from stac_pydantic import Item class CollectionPublisher: - def ingest(self, collection: DashboardCollection): + def ingest(self, collection: DashboardCollection, token: str, ingest_api: str): """ Takes a collection model, does necessary preprocessing, and loads into the PgSTAC collection table """ - # TODO get service auth token - creds = get_db_credentials(os.environ["DB_SECRET_ARN"]) - collection = [convert_decimals_to_float(collection.dict(by_alias=True))] - # TODO reroute to API - pass - -class ItemPublisher: - def ingest(self, item: Item): - """ - Takes an item model, - does necessary preprocessing, - and loads into the PgSTAC item table - """ - creds = get_db_credentials(os.environ["DB_SECRET_ARN"]) - item = [convert_decimals_to_float(item.dict(by_alias=True))] - with PgstacDB(dsn=creds.dsn_string, debug=True) as db: - load_into_pgstac(db=db, ingestions=item, table=IngestionType.items) + collection = collection.model_dump(by_alias=True) + url = f"{ingest_api}/collections" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + response = requests.post(url, json=collection, headers=headers) + if response.status_code == 200: + print("Success:", response.json()) + else: + print("Error:", response.status_code, response.text) # TODO refactor class Publisher: @@ -161,25 +155,22 @@ def generate_stac( create_function = self.func_map.get(data_type, self.create_cog_collection) return create_function(dataset) - def ingest(self, collection: DashboardCollection): + def ingest(self, collection: DashboardCollection, token: str, ingest_api: str): """ Takes a collection model, does necessary preprocessing, and loads into the PgSTAC collection table """ - db_creds = self._get_db_credentials() - # TODO reroute to ingest API - collection = [convert_decimals_to_float(collection.dict(by_alias=True))] - with PgstacDB(dsn=db_creds.dsn_string, debug=True) as db: - load_into_pgstac( - db=db, ingestions=collection, table=IngestionType.collections - ) + collection = collection.model_dump(by_alias=True) + + url = f"{ingest_api}/collections" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + response = requests.post(url, json=collection, headers=headers) + if response.status_code == 200: + print("Success:", response.json()) + else: + print("Error:", response.status_code, response.text) - def delete(self, collection_id: str): - """ - Deletes the collection from the database - """ - db_creds = self._get_db_credentials() - with PgstacDB(dsn=db_creds.dsn_string, debug=True) as db: - loader = VEDALoader(db=db) - loader.delete_collection(collection_id) diff --git a/workflows_api/runtime/src/config.py b/workflows_api/runtime/src/config.py index 0eac1ea5..fd0556b6 100644 --- a/workflows_api/runtime/src/config.py +++ b/workflows_api/runtime/src/config.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): ingest_url: str = Field(description="URL of ingest API") raster_url: str = Field(description="URL of raster API") stac_url: str = Field(description="URL of STAC API") + mwaa_env: str = Field(description="MWAA URL") class Config(): env_file = ".env" diff --git a/workflows_api/runtime/src/dependencies.py b/workflows_api/runtime/src/dependencies.py deleted file mode 100644 index 9b015d9c..00000000 --- a/workflows_api/runtime/src/dependencies.py +++ /dev/null @@ -1,12 +0,0 @@ -import logging - -import boto3 -import src.auth as auth -import src.config as config -import src.services as services - -from fastapi import Depends, HTTPException, security - -logger = logging.getLogger(__name__) - -token_scheme = security.HTTPBearer() diff --git a/workflows_api/runtime/src/main.py b/workflows_api/runtime/src/main.py index 60cd8ba0..d9ac142d 100644 --- a/workflows_api/runtime/src/main.py +++ b/workflows_api/runtime/src/main.py @@ -64,14 +64,15 @@ def validate_dataset(dataset: schemas.COGDataset): "/dataset/publish", tags=["Dataset"], dependencies=[Depends(auth.get_username)] ) async def publish_dataset( + token = Depends(auth.get_and_validate_token), dataset: Union[schemas.ZarrDataset, schemas.COGDataset] = Body( ..., discriminator="data_type" - ) + ), ): # Construct and load collection collection_data = publisher.generate_stac(dataset, dataset.data_type or "cog") collection = schemas.DashboardCollection.parse_obj(collection_data) - collection_publisher.ingest(collection) + collection_publisher.ingest(collection, token, settings.ingest_url) # TODO improve typing return_dict = {