From a3e1ed2f306d898c2a8ce9c8cd03de9508ddc9b2 Mon Sep 17 00:00:00 2001 From: Craig Rueda Date: Tue, 2 Apr 2024 09:00:32 -0700 Subject: [PATCH] feat(db): Adding DB_SQLA_URI_VALIDATOR (#27847) --- superset/config.py | 12 +++++++++++ superset/db_engine_specs/base.py | 3 +++ tests/unit_tests/db_engine_specs/test_base.py | 20 +++++++++++++++++++ .../unit_tests/extensions/test_sqlalchemy.py | 2 +- 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/superset/config.py b/superset/config.py index 0d00fedfbb9e3..04b0b909a61d3 100644 --- a/superset/config.py +++ b/superset/config.py @@ -44,6 +44,7 @@ from flask_caching.backends.base import BaseCache from pandas import Series from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module +from sqlalchemy.engine.url import URL from sqlalchemy.orm.query import Query from superset.advanced_data_type.plugins.internet_address import internet_address @@ -1207,6 +1208,17 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name DB_CONNECTION_MUTATOR = None +# A callable that is invoked for every invocation of DB Engine Specs +# which allows for custom validation of the engine URI. +# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri +# Example: +# def DB_ENGINE_URI_VALIDATOR(sqlalchemy_uri: URL): +# if not : +# raise Exception("URI invalid") +# +DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None + + # A function that intercepts the SQL to be executed and can alter it. # The use case is can be around adding some sort of comment header # with information such as the username and worker node information diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index e8790bdcd4f77..afe3592805446 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1956,6 +1956,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None: :param sqlalchemy_uri: """ + if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]: + db_engine_uri_validator(sqlalchemy_uri) + if existing_disallowed := cls.disallow_uri_query_params.get( sqlalchemy_uri.get_driver_name(), set() ).intersection(sqlalchemy_uri.query): diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 86eb37183f25d..4e2d3f4b44a89 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -23,6 +23,7 @@ from pytest_mock import MockFixture from sqlalchemy import types from sqlalchemy.dialects import sqlite +from sqlalchemy.engine.url import URL from sqlalchemy.sql import sqltypes from superset.superset_typing import ResultSetColumnType, SQLAColumnType @@ -69,6 +70,25 @@ def test_parse_sql_multi_statement() -> None: ] +def test_validate_db_uri(mocker: MockFixture) -> None: + """ + Ensures that the `validate_database_uri` method invokes the validator correctly + """ + + def mock_validate(sqlalchemy_uri: URL) -> None: + raise ValueError("Invalid URI") + + mocker.patch( + "superset.db_engine_specs.base.current_app.config", + {"DB_SQLA_URI_VALIDATOR": mock_validate}, + ) + + from superset.db_engine_specs.base import BaseEngineSpec + + with pytest.raises(ValueError): + BaseEngineSpec.validate_database_uri(URL.create("sqlite")) + + @pytest.mark.parametrize( "original,expected", [ diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py index caa141aaf7f14..c0fd49f9eb0e2 100644 --- a/tests/unit_tests/extensions/test_sqlalchemy.py +++ b/tests/unit_tests/extensions/test_sqlalchemy.py @@ -124,7 +124,7 @@ def test_superset_limit(mocker: MockFixture, app_context: None, table1: None) -> """ mocker.patch( "superset.extensions.metadb.current_app.config", - {"SUPERSET_META_DB_LIMIT": 1}, + {"DB_SQLA_URI_VALIDATOR": None, "SUPERSET_META_DB_LIMIT": 1}, ) mocker.patch("superset.extensions.metadb.security_manager")