diff --git a/terracotta/config.py b/terracotta/config.py index 86f0c66f..270e09ac 100644 --- a/terracotta/config.py +++ b/terracotta/config.py @@ -7,9 +7,12 @@ import os import json import tempfile +import warnings from marshmallow import Schema, fields, validate, pre_load, post_load, ValidationError +from terracotta import exceptions + class TerracottaSettings(NamedTuple): """Contains all settings for the current Terracotta instance.""" @@ -67,23 +70,37 @@ class TerracottaSettings(NamedTuple): #: CORS allowed origins for tiles endpoints ALLOWED_ORIGINS_TILES: List[str] = [r'http[s]?://(localhost|127\.0\.0\.1):*'] - #: MySQL database username (if not given in driver path) + #: SQL database username (if not given in driver path) + SQL_USER: Optional[str] = None + + #: SQL database password (if not given in driver path) + SQL_PASSWORD: Optional[str] = None + + #: Deprecated, use SQL_USER. MySQL database username (if not given in driver path) MYSQL_USER: Optional[str] = None - #: MySQL database password (if not given in driver path) + #: Deprecated, use SQL_PASSWORD. MySQL database password (if not given in driver path) MYSQL_PASSWORD: Optional[str] = None - #: PostgreSQL database username (if not given in driver path) + #: Deprecated, use SQL_USER. PostgreSQL database username (if not given in driver path) POSTGRESQL_USER: Optional[str] = None - #: PostgreSQL database password (if not given in driver path) + #: Deprecated, use SQL_PASSWORD. PostgreSQL database password (if not given in driver path) POSTGRESQL_PASSWORD: Optional[str] = None #: Use a process pool for band retrieval in parallel USE_MULTIPROCESSING: bool = True -AVAILABLE_SETTINGS: Tuple[str, ...] = tuple(TerracottaSettings._fields) +AVAILABLE_SETTINGS: Tuple[str, ...] = TerracottaSettings._fields + +DEPRECATION_MAP: Dict[str, str] = { + # TODO: Remove in v0.8.0 + 'MYSQL_USER': 'SQL_USER', + 'MYSQL_PASSWORD': 'SQL_PASSWORD', + 'POSTGRESQL_USER': 'SQL_USER', + 'POSTGRESQL_PASSWORD': 'SQL_PASSWORD', +} def _is_writable(path: str) -> bool: @@ -129,10 +146,13 @@ class SettingSchema(Schema): ALLOWED_ORIGINS_METADATA = fields.List(fields.String()) ALLOWED_ORIGINS_TILES = fields.List(fields.String()) - MYSQL_USER = fields.String() - MYSQL_PASSWORD = fields.String() - POSTGRESQL_USER = fields.String() - POSTGRESQL_PASSWORD = fields.String() + SQL_USER = fields.String(allow_none=True) + SQL_PASSWORD = fields.String(allow_none=True) + + MYSQL_USER = fields.String(allow_none=True) + MYSQL_PASSWORD = fields.String(allow_none=True) + POSTGRESQL_USER = fields.String(allow_none=True) + POSTGRESQL_PASSWORD = fields.String(allow_none=True) USE_MULTIPROCESSING = fields.Boolean() @@ -150,6 +170,23 @@ def decode_lists(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: ) from exc return data + @pre_load + def handle_deprecated_fields(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + for deprecated_field, new_field in DEPRECATION_MAP.items(): + if data.get(deprecated_field): + warnings.warn( + f'Setting TC_{deprecated_field} is deprecated ' + 'and will be removed in the next major release. ' + f'Please use TC_{new_field} instead.', + exceptions.DeprecationWarning + ) + + # Only use the mapping if the new field has not been set + if not data.get(new_field): + data[new_field] = data[deprecated_field] + + return data + @post_load def make_settings(self, data: Dict[str, Any], **kwargs: Any) -> TerracottaSettings: # encode tuples diff --git a/terracotta/drivers/relational_meta_store.py b/terracotta/drivers/relational_meta_store.py index 8e3e5a39..62179b7c 100644 --- a/terracotta/drivers/relational_meta_store.py +++ b/terracotta/drivers/relational_meta_store.py @@ -113,10 +113,11 @@ def _parse_path(cls, connection_string: str) -> URL: if con_params.scheme != cls.SQL_DIALECT: raise ValueError(f'unsupported URL scheme "{con_params.scheme}"') + settings = terracotta.get_settings() url = URL.create( drivername=f'{cls.SQL_DIALECT}+{cls.SQL_DRIVER}', - username=con_params.username, - password=con_params.password, + username=con_params.username or settings.SQL_USER, + password=con_params.password or settings.SQL_PASSWORD, host=con_params.hostname, port=con_params.port, database=con_params.path[1:], # remove leading '/' from urlparse diff --git a/terracotta/exceptions.py b/terracotta/exceptions.py index b6d904b2..3ba3eeab 100644 --- a/terracotta/exceptions.py +++ b/terracotta/exceptions.py @@ -26,3 +26,7 @@ class InvalidDatabaseError(Exception): class PerformanceWarning(UserWarning): pass + + +class DeprecationWarning(UserWarning): + pass diff --git a/tests/drivers/test_drivers.py b/tests/drivers/test_drivers.py index a584833f..ea8f1ccf 100644 --- a/tests/drivers/test_drivers.py +++ b/tests/drivers/test_drivers.py @@ -261,3 +261,23 @@ def test_invalid_key_types(driver_path, provider): with pytest.raises(exceptions.InvalidKeyError) as exc: db.get_datasets({'not-a-key': 'val'}) assert 'unrecognized keys' in str(exc) + + +@pytest.mark.parametrize('provider', TESTABLE_DRIVERS) +def test_use_credentials_from_settings(driver_path, provider, monkeypatch): + with monkeypatch.context() as m: + m.setenv('TC_SQL_USER', 'foo') + m.setenv('TC_SQL_PASSWORD', 'bar') + + from terracotta import drivers, update_settings + update_settings() + + if 'sqlite' not in provider: + meta_store_class = drivers.load_driver(provider) + assert meta_store_class._parse_path('').username == 'foo' + assert meta_store_class._parse_path('').password == 'bar' + + driver_path_without_credentials = driver_path[driver_path.find('@') + 1:] + db = drivers.get_driver(driver_path_without_credentials, provider) + assert db.meta_store.url.username == 'foo' + assert db.meta_store.url.password == 'bar' diff --git a/tests/test_config.py b/tests/test_config.py index 4d2af340..da3f944f 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -76,3 +76,26 @@ def test_update_config(): update_settings(DEFAULT_TILE_SIZE=[50, 50]) new_settings = get_settings() assert new_settings.DRIVER_PATH == 'test' and new_settings.DEFAULT_TILE_SIZE == (50, 50) + + +def test_deprecation_behaviour(monkeypatch): + from terracotta import config, exceptions, get_settings, update_settings + for deprecated_field, new_field in config.DEPRECATION_MAP.items(): + with monkeypatch.context() as m: + m.setenv(f'TC_{deprecated_field}', 'foo') + + with pytest.warns(exceptions.DeprecationWarning) as warning: + update_settings() + assert f'TC_{deprecated_field} is deprecated' in str(warning[0]) + + assert getattr(get_settings(), deprecated_field) == 'foo' + assert getattr(get_settings(), new_field) == 'foo' + + m.setenv(f'TC_{new_field}', 'bar') + + with pytest.warns(exceptions.DeprecationWarning) as warning: + update_settings() + assert f'TC_{deprecated_field} is deprecated' in str(warning[0]) + + assert getattr(get_settings(), deprecated_field) == 'foo' + assert getattr(get_settings(), new_field) == 'bar'