diff --git a/stac_api/clients/base.py b/stac_api/clients/base.py index 06d04e0d9..77a4dfe15 100644 --- a/stac_api/clients/base.py +++ b/stac_api/clients/base.py @@ -4,11 +4,12 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Type, Union -from stac_api.api.extensions.extension import ApiExtension -from stac_api.models import schemas from stac_pydantic import ItemCollection from stac_pydantic.api import ConformanceClasses, LandingPage +from stac_api.api.extensions.extension import ApiExtension +from stac_api.models import schemas + NumType = Union[float, int] @@ -51,6 +52,22 @@ def delete_collection(self, id: str, **kwargs) -> schemas.Collection: ... +@dataclass  # type: ignore +class BulkTransactionsClient(abc.ABC): +    """bulk transactions client""" + +    @staticmethod +    def _chunks(lst, n): +        """Yield successive n-sized chunks from lst.""" +        for i in range(0, len(lst), n): +            yield lst[i : i + n] + +    @abc.abstractmethod +    def bulk_item_insert(self, items: List[Dict], chunk_size: Optional[int] = None) -> None: +        """bulk item insertion, not implemented by default, and not exposed through the api""" +        raise NotImplementedError + + @dataclass  # type:ignore class BaseCoreClient(abc.ABC): """Base client for core endpoints defined by stac""" diff --git a/stac_api/clients/postgres/transactions.py b/stac_api/clients/postgres/transactions.py index 937ea8daf..a478fde5a 100644 --- a/stac_api/clients/postgres/transactions.py +++ b/stac_api/clients/postgres/transactions.py @@ -1,11 +1,14 @@ """transactions extension client""" +import json import logging from dataclasses import dataclass -from typing import Type, Union +from typing import Dict, List, Optional, Type, Union + +from sqlalchemy import create_engine from stac_api import errors -from stac_api.clients.base import BaseTransactionsClient +from stac_api.clients.base import BaseTransactionsClient, BulkTransactionsClient from stac_api.clients.postgres.base import PostgresClient from stac_api.models 
import database, schemas @@ -110,3 +113,42 @@ def delete_collection(self, id: str, **kwargs) -> schemas.Collection: obj = self._delete(id, table=self.collection_table) obj.base_url = str(kwargs["request"].base_url) return schemas.Collection.from_orm(obj) + + +@dataclass +class PostgresBulkTransactions(BulkTransactionsClient): + """postgres bulk transactions""" + + connection_str: str + debug: bool = False + + def __post_init__(self): + """create sqlalchemy engine""" + self.engine = create_engine(self.connection_str, echo=self.debug) + + @staticmethod + def _preprocess_item(item) -> Dict: + """ + preprocess items to match data model + # TODO: dedup with GetterDict logic (ref #58) + """ + item["geometry"] = json.dumps(item["geometry"]) + item["collection_id"] = item.pop("collection") + item["datetime"] = item["properties"].pop("datetime") + return item + + def bulk_item_insert( + self, items: List[Dict], chunk_size: Optional[int] = None + ) -> None: + """ + bulk item insertion using sqlalchemy core + https://docs.sqlalchemy.org/en/13/faq/performance.html#i-m-inserting-400-000-rows-with-the-orm-and-it-s-really-slow + """ + items = [self._preprocess_item(item) for item in items] + if chunk_size: + for chunk in self._chunks(items, chunk_size): + self.engine.execute(database.Item.__table__.insert(), chunk) + return + + self.engine.execute(database.Item.__table__.insert(), items) + return diff --git a/tests/clients/test_postgres.py b/tests/clients/test_postgres.py index bb908018a..2755fc3de 100644 --- a/tests/clients/test_postgres.py +++ b/tests/clients/test_postgres.py @@ -4,7 +4,10 @@ import pytest from stac_api.clients.postgres.core import CoreCrudClient -from stac_api.clients.postgres.transactions import TransactionsClient +from stac_api.clients.postgres.transactions import ( + PostgresBulkTransactions, + TransactionsClient, +) from stac_api.errors import ConflictError, NotFoundError from stac_api.models.schemas import Collection, Item from tests.conftest import 
MockStarletteRequest @@ -161,3 +164,47 @@ def test_delete_item( with pytest.raises(NotFoundError): postgres_core.get_item(item.id, request=MockStarletteRequest) + + +def test_bulk_item_insert( + postgres_transactions: TransactionsClient, + postgres_bulk_transactions: PostgresBulkTransactions, + load_test_data: Callable, +): + coll = Collection.parse_obj(load_test_data("test_collection.json")) + postgres_transactions.create_collection(coll, request=MockStarletteRequest) + + item = Item.parse_obj(load_test_data("test_item.json")) + + items = [] + for _ in range(10): + _item = item.dict() + _item["id"] = str(uuid.uuid4()) + items.append(_item) + + postgres_bulk_transactions.bulk_item_insert(items) + + for item in items: + postgres_transactions.delete_item(item["id"], request=MockStarletteRequest) + + +def test_bulk_item_insert_chunked( + postgres_transactions: TransactionsClient, + postgres_bulk_transactions: PostgresBulkTransactions, + load_test_data: Callable, +): + coll = Collection.parse_obj(load_test_data("test_collection.json")) + postgres_transactions.create_collection(coll, request=MockStarletteRequest) + + item = Item.parse_obj(load_test_data("test_item.json")) + + items = [] + for _ in range(10): + _item = item.dict() + _item["id"] = str(uuid.uuid4()) + items.append(_item) + + postgres_bulk_transactions.bulk_item_insert(items, chunk_size=2) + + for item in items: + postgres_transactions.delete_item(item["id"], request=MockStarletteRequest) diff --git a/tests/conftest.py b/tests/conftest.py index 047a31455..bb490115a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,11 +3,11 @@ from typing import Callable, Dict, Generator from unittest.mock import PropertyMock, patch +import pytest from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker from starlette.testclient import TestClient -import pytest from stac_api.api.app import StacApi from stac_api.api.extensions import ( ContextExtension, @@ -18,7 +18,10 @@ ) from 
stac_api.clients.postgres.core import CoreCrudClient from stac_api.clients.postgres.tokens import PaginationTokenClient -from stac_api.clients.postgres.transactions import TransactionsClient +from stac_api.clients.postgres.transactions import ( + PostgresBulkTransactions, + TransactionsClient, +) from stac_api.config import ApiSettings, inject_settings from stac_api.models.schemas import Collection @@ -67,23 +70,30 @@ class MockStarletteRequest: @pytest.fixture -def reader_connection() -> Generator[Session, None, None]: - """Create a reader connection""" +def sqlalchemy_engine(): engine = create_engine(settings.reader_connection_string) - db_session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() + yield engine + engine.dispose() + + +@pytest.fixture +def reader_connection(sqlalchemy_engine) -> Generator[Session, None, None]: + """Create a reader connection""" + db_session = sessionmaker( + autocommit=False, autoflush=False, bind=sqlalchemy_engine + )() yield db_session db_session.close() - engine.dispose() @pytest.fixture -def writer_connection() -> Generator[Session, None, None]: +def writer_connection(sqlalchemy_engine) -> Generator[Session, None, None]: """Create a writer connection""" - engine = create_engine(settings.writer_connection_string) - db_session = sessionmaker(autocommit=False, autoflush=False, bind=engine)() + db_session = sessionmaker( + autocommit=False, autoflush=False, bind=sqlalchemy_engine + )() yield db_session db_session.close() - engine.dispose() @pytest.fixture @@ -118,6 +128,12 @@ def postgres_transactions(reader_connection, writer_connection): yield client +@pytest.fixture +def postgres_bulk_transactions(): + client = PostgresBulkTransactions(connection_str=settings.writer_connection_string) + return client + + @pytest.fixture def api_client(): return StacApi(