Skip to content

Commit

Permalink
add bulk transactions client (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
geospatial-jeff authored Jan 12, 2021
1 parent 016b121 commit e93c9b1
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 15 deletions.
21 changes: 19 additions & 2 deletions stac_api/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Type, Union

from stac_api.api.extensions.extension import ApiExtension
from stac_api.models import schemas
from stac_pydantic import ItemCollection
from stac_pydantic.api import ConformanceClasses, LandingPage

from stac_api.api.extensions.extension import ApiExtension
from stac_api.models import schemas

NumType = Union[float, int]


Expand Down Expand Up @@ -51,6 +52,22 @@ def delete_collection(self, id: str, **kwargs) -> schemas.Collection:
...


@dataclass  # type: ignore
class BulkTransactionsClient(abc.ABC):
    """Abstract base client for bulk item insertion.

    Implementations supply an engine-specific ``bulk_item_insert``; the
    chunking helper is shared so subclasses can batch large payloads.
    """

    @staticmethod
    def _chunks(lst, n):
        """Yield successive n-sized chunks from lst.

        The final chunk may be shorter than ``n`` when ``len(lst)`` is not
        a multiple of ``n``.
        """
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    @abc.abstractmethod
    def bulk_item_insert(
        self, items: List[Dict], chunk_size: Optional[int] = None
    ) -> None:
        """Bulk item insertion; not exposed through the api.

        Parameter renamed ``chunks`` -> ``chunk_size`` so the abstract
        signature matches the concrete implementation and keyword calls
        (``bulk_item_insert(items, chunk_size=2)``) work consistently.
        """
        raise NotImplementedError


@dataclass # type:ignore
class BaseCoreClient(abc.ABC):
"""Base client for core endpoints defined by stac"""
Expand Down
46 changes: 44 additions & 2 deletions stac_api/clients/postgres/transactions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""transactions extension client"""

import json
import logging
from dataclasses import dataclass
from typing import Type, Union
from typing import Dict, List, Optional, Type, Union

from sqlalchemy import create_engine

from stac_api import errors
from stac_api.clients.base import BaseTransactionsClient
from stac_api.clients.base import BaseTransactionsClient, BulkTransactionsClient
from stac_api.clients.postgres.base import PostgresClient
from stac_api.models import database, schemas

Expand Down Expand Up @@ -110,3 +113,42 @@ def delete_collection(self, id: str, **kwargs) -> schemas.Collection:
obj = self._delete(id, table=self.collection_table)
obj.base_url = str(kwargs["request"].base_url)
return schemas.Collection.from_orm(obj)


@dataclass
class PostgresBulkTransactions(BulkTransactionsClient):
    """Postgres bulk transactions client using sqlalchemy core inserts."""

    # Database connection string used to build the sqlalchemy engine.
    connection_str: str
    # When True, sqlalchemy echoes emitted SQL (useful for debugging).
    debug: bool = False

    def __post_init__(self):
        """Create the sqlalchemy engine from the connection string."""
        self.engine = create_engine(self.connection_str, echo=self.debug)

    @staticmethod
    def _preprocess_item(item) -> Dict:
        """Preprocess an item to match the database model.

        Returns a new dict; the caller's item — including its nested
        ``properties`` dict — is left unmodified (the original popped keys
        in place, corrupting the caller's data on every insert).
        # TODO: dedup with GetterDict logic (ref #58)
        """
        out = dict(item)
        # Geometry is stored as a JSON string, not a nested mapping.
        out["geometry"] = json.dumps(out["geometry"])
        # The database column is named collection_id.
        out["collection_id"] = out.pop("collection")
        # Datetime is promoted from properties to a top-level column;
        # copy properties first so the caller's dict is not mutated.
        properties = dict(out["properties"])
        out["datetime"] = properties.pop("datetime")
        out["properties"] = properties
        return out

    def bulk_item_insert(
        self, items: List[Dict], chunk_size: Optional[int] = None
    ) -> None:
        """Bulk item insertion using sqlalchemy core.

        https://docs.sqlalchemy.org/en/13/faq/performance.html#i-m-inserting-400-000-rows-with-the-orm-and-it-s-really-slow

        If ``chunk_size`` is given, rows are inserted in batches of that
        size; otherwise all rows go in a single executemany insert.
        """
        rows = [self._preprocess_item(item) for item in items]
        if chunk_size:
            for chunk in self._chunks(rows, chunk_size):
                self.engine.execute(database.Item.__table__.insert(), chunk)
        else:
            self.engine.execute(database.Item.__table__.insert(), rows)
49 changes: 48 additions & 1 deletion tests/clients/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import pytest

from stac_api.clients.postgres.core import CoreCrudClient
from stac_api.clients.postgres.transactions import TransactionsClient
from stac_api.clients.postgres.transactions import (
PostgresBulkTransactions,
TransactionsClient,
)
from stac_api.errors import ConflictError, NotFoundError
from stac_api.models.schemas import Collection, Item
from tests.conftest import MockStarletteRequest
Expand Down Expand Up @@ -161,3 +164,47 @@ def test_delete_item(

with pytest.raises(NotFoundError):
postgres_core.get_item(item.id, request=MockStarletteRequest)


def test_bulk_item_insert(
    postgres_transactions: TransactionsClient,
    postgres_bulk_transactions: PostgresBulkTransactions,
    load_test_data: Callable,
):
    """Bulk-insert ten items into a fresh collection, then delete each one."""
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=MockStarletteRequest)

    base_item = Item.parse_obj(load_test_data("test_item.json"))
    items = [{**base_item.dict(), "id": str(uuid.uuid4())} for _ in range(10)]

    postgres_bulk_transactions.bulk_item_insert(items)

    # Deleting by id succeeds only if every item was actually inserted.
    for inserted in items:
        postgres_transactions.delete_item(inserted["id"], request=MockStarletteRequest)


def test_bulk_item_insert_chunked(
    postgres_transactions: TransactionsClient,
    postgres_bulk_transactions: PostgresBulkTransactions,
    load_test_data: Callable,
):
    """Bulk-insert ten items in chunks of two, then delete each one."""
    collection = Collection.parse_obj(load_test_data("test_collection.json"))
    postgres_transactions.create_collection(collection, request=MockStarletteRequest)

    base_item = Item.parse_obj(load_test_data("test_item.json"))
    items = [{**base_item.dict(), "id": str(uuid.uuid4())} for _ in range(10)]

    postgres_bulk_transactions.bulk_item_insert(items, chunk_size=2)

    # Deleting by id succeeds only if every chunk was actually inserted.
    for inserted in items:
        postgres_transactions.delete_item(inserted["id"], request=MockStarletteRequest)
36 changes: 26 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from typing import Callable, Dict, Generator
from unittest.mock import PropertyMock, patch

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from starlette.testclient import TestClient

import pytest
from stac_api.api.app import StacApi
from stac_api.api.extensions import (
ContextExtension,
Expand All @@ -18,7 +18,10 @@
)
from stac_api.clients.postgres.core import CoreCrudClient
from stac_api.clients.postgres.tokens import PaginationTokenClient
from stac_api.clients.postgres.transactions import TransactionsClient
from stac_api.clients.postgres.transactions import (
PostgresBulkTransactions,
TransactionsClient,
)
from stac_api.config import ApiSettings, inject_settings
from stac_api.models.schemas import Collection

Expand Down Expand Up @@ -67,23 +70,30 @@ class MockStarletteRequest:


@pytest.fixture
def sqlalchemy_engine():
    """Yield a shared sqlalchemy engine, disposed at teardown.

    Reconstructed from the added lines of this hunk: the scrape interleaved
    the deleted ``reader_connection`` header/docstring and a dead
    ``sessionmaker`` line (a second ``def`` under one decorator is a syntax
    error), which are removed here.
    """
    engine = create_engine(settings.reader_connection_string)
    yield engine
    engine.dispose()


@pytest.fixture
def reader_connection(sqlalchemy_engine) -> Generator[Session, None, None]:
    """Create a reader session bound to the shared engine.

    Only the session is closed here; the engine's lifecycle (including
    dispose) is owned by the ``sqlalchemy_engine`` fixture. The trailing
    ``engine.dispose()`` residue referenced an undefined local ``engine``
    and would raise NameError at teardown — removed.
    """
    db_session = sessionmaker(
        autocommit=False, autoflush=False, bind=sqlalchemy_engine
    )()
    yield db_session
    db_session.close()


@pytest.fixture
def writer_connection(sqlalchemy_engine) -> Generator[Session, None, None]:
    """Create a writer session bound to the shared engine.

    Stale pre-change lines (a local ``create_engine`` call and an
    immediately-overwritten ``db_session`` assignment) and the
    ``engine.dispose()`` on the shared engine are removed.
    NOTE(review): the shared engine is built from the *reader* connection
    string — confirm reader/writer point at the same test database.
    """
    db_session = sessionmaker(
        autocommit=False, autoflush=False, bind=sqlalchemy_engine
    )()
    yield db_session
    db_session.close()
engine.dispose()


@pytest.fixture
Expand Down Expand Up @@ -118,6 +128,12 @@ def postgres_transactions(reader_connection, writer_connection):
yield client


@pytest.fixture
def postgres_bulk_transactions():
    """Build a bulk transactions client against the writer database."""
    return PostgresBulkTransactions(connection_str=settings.writer_connection_string)


@pytest.fixture
def api_client():
return StacApi(
Expand Down

0 comments on commit e93c9b1

Please sign in to comment.