diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b77ca7de..20baafc3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,7 +8,8 @@ on: - main env: - DB_PORT: 3306 + MYSQL_PORT: 3306 + POSTGRESQL_PORT: 5432 DB_USER: root DB_PASSWORD: root @@ -30,6 +31,20 @@ jobs: defaults: run: shell: bash + + services: + postgres: + image: postgres + env: + POSTGRES_USER: ${{ env.DB_USER }} + POSTGRES_PASSWORD: ${{ env.DB_PASSWORD }} + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 steps: - uses: actions/checkout@v2 @@ -76,8 +91,9 @@ jobs: - name: Run tests run: | - MYSQL_SRV="${{ env.DB_USER }}:${{ env.DB_PASSWORD }}@127.0.0.1:${{ env.DB_PORT }}" - python -m pytest . --color=yes --cov=terracotta --mysql-server=$MYSQL_SRV + MYSQL_SRV="${{ env.DB_USER }}:${{ env.DB_PASSWORD }}@127.0.0.1:${{ env.MYSQL_PORT }}" + POSTGRESQL_SRV="${{ env.DB_USER }}:${{ env.DB_PASSWORD }}@localhost:${{ env.POSTGRESQL_PORT }}" + python -m pytest . 
--color=yes --cov=terracotta --mysql-server=$MYSQL_SRV --postgresql-server=$POSTGRESQL_SRV - name: Run benchmarks run: | diff --git a/setup.py b/setup.py index bf34bf66..7afbb20c 100644 --- a/setup.py +++ b/setup.py @@ -90,18 +90,21 @@ 'matplotlib', 'moto', 'aws-xray-sdk', - 'pymysql>=1.0.0' + 'pymysql>=1.0.0', + 'psycopg2' ], 'docs': [ 'sphinx', 'sphinx_autodoc_typehints', 'sphinx-click', - 'pymysql>=1.0.0' + 'pymysql>=1.0.0', + 'psycopg2' ], 'recommended': [ 'colorlog', 'crick', - 'pymysql>=1.0.0' + 'pymysql>=1.0.0', + 'psycopg2' ] }, # CLI diff --git a/terracotta/config.py b/terracotta/config.py index ee44f563..86f0c66f 100644 --- a/terracotta/config.py +++ b/terracotta/config.py @@ -16,7 +16,7 @@ class TerracottaSettings(NamedTuple): #: Path to database DRIVER_PATH: str = '' - #: Driver provider to use (sqlite, sqlite-remote, mysql; auto-detected by default) + #: Driver provider to use (sqlite, sqlite-remote, mysql, postgresql; auto-detected by default) DRIVER_PROVIDER: Optional[str] = None #: Activate debug mode in Flask app @@ -73,6 +73,12 @@ class TerracottaSettings(NamedTuple): #: MySQL database password (if not given in driver path) MYSQL_PASSWORD: Optional[str] = None + #: PostgreSQL database username (if not given in driver path) + POSTGRESQL_USER: Optional[str] = None + + #: PostgreSQL database password (if not given in driver path) + POSTGRESQL_PASSWORD: Optional[str] = None + #: Use a process pool for band retrieval in parallel USE_MULTIPROCESSING: bool = True @@ -125,6 +131,8 @@ class SettingSchema(Schema): MYSQL_USER = fields.String() MYSQL_PASSWORD = fields.String() + POSTGRESQL_USER = fields.String() + POSTGRESQL_PASSWORD = fields.String() USE_MULTIPROCESSING = fields.Boolean() diff --git a/terracotta/drivers/__init__.py b/terracotta/drivers/__init__.py index 26f38988..5799ade4 100644 --- a/terracotta/drivers/__init__.py +++ b/terracotta/drivers/__init__.py @@ -24,6 +24,10 @@ def load_driver(provider: str) -> Type[MetaStore]: from 
terracotta.drivers.mysql_meta_store import MySQLMetaStore return MySQLMetaStore + if provider == 'postgresql': + from terracotta.drivers.postgresql_meta_store import PostgreSQLMetaStore + return PostgreSQLMetaStore + if provider == 'sqlite': from terracotta.drivers.sqlite_meta_store import SQLiteMetaStore return SQLiteMetaStore @@ -41,6 +45,9 @@ def auto_detect_provider(url_or_path: str) -> str: if scheme == 'mysql': return 'mysql' + if scheme == 'postgresql': + return 'postgresql' + return 'sqlite' @@ -61,7 +68,7 @@ def get_driver(url_or_path: URLOrPathType, provider: str = None) -> TerracottaDr url_or_path: A path identifying the database to connect to. The expected format depends on the driver provider. - provider: Driver provider to use (one of sqlite, sqlite-remote, mysql; + provider: Driver provider to use (one of sqlite, sqlite-remote, mysql, postgresql; default: auto-detect). Example: diff --git a/terracotta/drivers/postgresql_meta_store.py b/terracotta/drivers/postgresql_meta_store.py new file mode 100644 index 00000000..c0bd907f --- /dev/null +++ b/terracotta/drivers/postgresql_meta_store.py @@ -0,0 +1,83 @@ +"""drivers/postgresql_meta_store.py + +PostgreSQL-backed metadata driver. Metadata is stored in a PostgreSQL database. +""" + +from typing import Mapping, Sequence + +import sqlalchemy as sqla +from terracotta.drivers.relational_meta_store import RelationalMetaStore + + +class PostgreSQLMetaStore(RelationalMetaStore): + """A PostgreSQL-backed metadata driver. + + Stores metadata and paths to raster files in PostgreSQL. + + Requires a running PostgreSQL server. + + The PostgreSQL database consists of 4 different tables: + + - ``terracotta``: Metadata about the database itself. + - ``key_names``: Contains two columns holding all available keys and their description. + - ``datasets``: Maps key values to physical raster path. + - ``metadata``: Contains actual metadata as separate columns. Indexed via key values. + + This driver caches key names. 
+ """ + SQL_DIALECT = 'postgresql' + SQL_DRIVER = 'psycopg2' + SQL_TIMEOUT_KEY = 'connect_timeout' + + MAX_PRIMARY_KEY_SIZE = 2730 // 4 # Max B-tree index size in bytes + DEFAULT_PORT = 5432 + # Will connect to this db before creating the 'terracotta' db + DEFAULT_CONNECT_DB = 'postgres' + + def __init__(self, postgresql_path: str) -> None: + """Initialize the PostgreSQLDriver. + + This should not be called directly, use :func:`~terracotta.get_driver` instead. + + Arguments: + + postgresql_path: URL to running PostgreSQL server, in the form + ``postgresql://username:password@hostname/database`` + + """ + super().__init__(postgresql_path) + + # raise an exception if database name is invalid + if not self.url.database: + raise ValueError('database must be specified in PostgreSQL path') + if '/' in self.url.database.strip('/'): + raise ValueError('invalid database path') + + @classmethod + def _normalize_path(cls, path: str) -> str: + url = cls._parse_path(path) + + path = f'{url.drivername}://{url.host}:{url.port or cls.DEFAULT_PORT}/{url.database}' + path = path.rstrip('/') + return path + + def _create_database(self) -> None: + engine = sqla.create_engine( + # `.set()` returns a copy with changed parameters + self.url.set(database=self.DEFAULT_CONNECT_DB), + echo=False, + future=True, + isolation_level='AUTOCOMMIT' + ) + with engine.connect() as connection: + connection.execute(sqla.text(f'CREATE DATABASE {self.url.database}')) + connection.commit() + + def _initialize_database( + self, + keys: Sequence[str], + key_descriptions: Mapping[str, str] = None + ) -> None: + # Enforce max primary key length equal to max B-tree index size + self.SQL_KEY_SIZE = self.MAX_PRIMARY_KEY_SIZE // len(keys) + super()._initialize_database(keys, key_descriptions) diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index 2ccdf45f..8e3e5a39 100644 --- a/terracotta/drivers/relational_meta_store.py +++ 
b/terracotta/drivers/relational_meta_store.py @@ -185,7 +185,7 @@ def db_version(self) -> str: def create(self, keys: Sequence[str], key_descriptions: Mapping[str, str] = None) -> None: """Create and initialize database with empty tables. - This must be called before opening the first connection. The MySQL database must not + This must be called before opening the first connection. The database must not exist already. Arguments: diff --git a/tests/conftest.py b/tests/conftest.py index 44a85ffe..a32b8349 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,10 @@ def pytest_addoption(parser): '--mysql-server', help='MySQL server to use for testing in the form of user:password@host:port' ) + parser.addoption( + '--postgresql-server', + help='PostgreSQL server to use for testing in the form of user:password@host:port' + ) @pytest.fixture() @@ -46,6 +50,11 @@ def mysql_server(request): return request.config.getoption('mysql_server') +@pytest.fixture() +def postgresql_server(request): + return request.config.getoption('postgresql_server') + + def cloud_optimize(raster_file, outfile, create_mask=False, remove_nodata=False): import math import contextlib @@ -372,58 +381,70 @@ def test_server(testdb): @pytest.fixture() -def driver_path(provider, tmpdir, mysql_server): +def driver_path(provider, tmpdir, mysql_server, postgresql_server): """Get a valid, uninitialized driver path for given provider""" import random import string + from terracotta import drivers + from terracotta.exceptions import InvalidDatabaseError + import sqlalchemy as sqla from urllib.parse import urlparse - def validate_con_info(con_info): - return (con_info.scheme == 'mysql' + def validate_con_info(con_info, db_scheme): + return (con_info.scheme == db_scheme and con_info.hostname and con_info.username and not con_info.path) def random_string(length): - return ''.join(random.choices(string.ascii_uppercase, k=length)) + return ''.join(random.choices(string.ascii_lowercase, k=length)) if 
provider == 'sqlite': dbfile = tmpdir.join('test.sqlite') yield str(dbfile) - elif provider == 'mysql': - if not mysql_server: - return pytest.skip('mysql_server argument not given') + elif provider == 'mysql' or provider == 'postgresql': - if not mysql_server.startswith('mysql://'): - mysql_server = f'mysql://{mysql_server}' + if provider == 'mysql': + db_server = mysql_server + con_db = 'mysql' + elif provider == 'postgresql': + db_server = postgresql_server + con_db = 'postgres' - con_info = urlparse(mysql_server) - if not validate_con_info(con_info): - raise ValueError('invalid value for mysql_server') + if not db_server: + return pytest.skip(f'{provider}_server argument not given') - dbpath = random_string(24) + if not db_server.startswith(f'{provider}://'): + db_server = f'{provider}://{db_server}' - import pymysql + con_info = urlparse(db_server) + if not validate_con_info(con_info, provider): + raise ValueError(f'invalid value for {provider}_server') + + driver = drivers.get_driver(f'{db_server}/{con_db}', provider=provider) try: - with pymysql.connect(host=con_info.hostname, user=con_info.username, - password=con_info.password): + with driver.connect(verify=False): pass - except pymysql.OperationalError as exc: - raise RuntimeError('error connecting to MySQL server') from exc + except InvalidDatabaseError as exc: + raise RuntimeError(f'error connecting to {provider} server') from exc + dbpath = random_string(24) try: - yield f'{mysql_server}/{dbpath}' + yield f'{db_server}/{dbpath}' finally: # cleanup - with pymysql.connect(host=con_info.hostname, user=con_info.username, - password=con_info.password) as connection: - with connection.cursor() as cursor: - try: - cursor.execute(f'DROP DATABASE IF EXISTS {dbpath}') - except pymysql.Warning: - pass + with driver.meta_store.sqla_engine.connect().execution_options( + isolation_level='AUTOCOMMIT' + ) as conn: + if provider == 'postgresql': + # Postgres refuses to drop DB if any sessions are hanging around + 
conn.execute(sqla.text( + "SELECT pg_terminate_backend(pg_stat_activity.pid) " + "FROM pg_stat_activity " + f"WHERE pg_stat_activity.datname = '{dbpath}'")) + conn.execute(sqla.text(f'DROP DATABASE IF EXISTS {dbpath}')) else: return NotImplementedError(f'unknown provider {provider}') diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index d3881839..a584833f 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -1,10 +1,11 @@ import pytest -TESTABLE_DRIVERS = ['sqlite', 'mysql'] +TESTABLE_DRIVERS = ['sqlite', 'mysql', 'postgresql'] DRIVER_CLASSES = { 'sqlite': 'SQLiteMetaStore', 'sqlite-remote': 'SQLiteRemoteMetaStore', - 'mysql': 'MySQLMetaStore' + 'mysql': 'MySQLMetaStore', + 'postgresql': 'PostgreSQLMetaStore' } diff --git a/tests/drivers/test_mysql.py b/tests/drivers/test_paths.py similarity index 60% rename from tests/drivers/test_mysql.py rename to tests/drivers/test_paths.py index e51aec44..4379bce3 100644 --- a/tests/drivers/test_mysql.py +++ b/tests/drivers/test_paths.py @@ -1,16 +1,16 @@ import pytest TEST_CASES = { - 'mysql://root@localhost:5000/test': dict( + '{provider}://root@localhost:5000/test': dict( username='root', password=None, host='localhost', port=5000, database='test' ), 'root@localhost:5000/test': dict( username='root', password=None, host='localhost', port=5000, database='test' ), - 'mysql://root:foo@localhost/test': dict( + '{provider}://root:foo@localhost/test': dict( username='root', password='foo', host='localhost', port=None, database='test' ), - 'mysql://localhost/test': dict( + '{provider}://localhost/test': dict( password=None, host='localhost', port=None, database='test' ), 'localhost/test': dict( @@ -20,25 +20,29 @@ INVALID_TEST_CASES = [ 'http://localhost/test', # wrong scheme - 'mysql://localhost', # no database - 'mysql://localhost/test/foo/bar' # path too deep + '{provider}://localhost', # no database + '{provider}://localhost/test/foo/bar' # path too deep ] 
@pytest.mark.parametrize('case', TEST_CASES.keys()) -def test_path_parsing(case): +@pytest.mark.parametrize('provider', ('mysql', 'postgresql')) +def test_path_parsing(case, provider): from terracotta import drivers # empty cache drivers._DRIVER_CACHE = {} - db = drivers.get_driver(case, provider='mysql') + db_path = case.format(provider=provider) + db = drivers.get_driver(db_path, provider=provider) for attr in ('username', 'password', 'host', 'port', 'database'): assert getattr(db.meta_store.url, attr) == TEST_CASES[case].get(attr, None) @pytest.mark.parametrize('case', INVALID_TEST_CASES) -def test_invalid_paths(case): +@pytest.mark.parametrize('provider', ('mysql', 'postgresql')) +def test_invalid_paths(case, provider): from terracotta import drivers + db_path = case.format(provider=provider) with pytest.raises(ValueError): - drivers.get_driver(case, provider='mysql') + drivers.get_driver(db_path, provider=provider)