Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(db): Adding DB_SQLA_URI_VALIDATOR #27847

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from flask_caching.backends.base import BaseCache
from pandas import Series
from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module
from sqlalchemy.engine.url import URL
from sqlalchemy.orm.query import Query

from superset.advanced_data_type.plugins.internet_address import internet_address
Expand Down Expand Up @@ -1206,6 +1207,17 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
DB_CONNECTION_MUTATOR = None


# A callable that is invoked for every invocation of DB Engine Specs
# which allows for custom validation of the engine URI.
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
# Example:
# def DB_ENGINE_URI_VALIDATOR(sqlalchemy_uri: URL):
# if not <some condition>:
# raise Exception("URI invalid")
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None


# A function that intercepts the SQL to be executed and can alter it.
# The use case is can be around adding some sort of comment header
# with information such as the username and worker node information
Expand Down
3 changes: 3 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,9 @@ def validate_database_uri(cls, sqlalchemy_uri: URL) -> None:
:param sqlalchemy_uri:
"""
if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]:
db_engine_uri_validator(sqlalchemy_uri)

if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):
Expand Down
20 changes: 20 additions & 0 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytest_mock import MockFixture
from sqlalchemy import types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from superset.superset_typing import ResultSetColumnType, SQLAColumnType
Expand Down Expand Up @@ -69,6 +70,25 @@ def test_parse_sql_multi_statement() -> None:
]


def test_validate_db_uri(mocker: MockFixture) -> None:
"""
Ensures that the `validate_database_uri` method invokes the validator correctly
"""

def mock_validate(sqlalchemy_uri: URL) -> None:
raise ValueError("Invalid URI")

mocker.patch(
"superset.db_engine_specs.base.current_app.config",
{"DB_SQLA_URI_VALIDATOR": mock_validate},
)

from superset.db_engine_specs.base import BaseEngineSpec

with pytest.raises(ValueError):
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))


@pytest.mark.parametrize(
"original,expected",
[
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/extensions/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_superset_limit(mocker: MockFixture, app_context: None, table1: None) ->
"""
mocker.patch(
"superset.extensions.metadb.current_app.config",
{"SUPERSET_META_DB_LIMIT": 1},
{"DB_SQLA_URI_VALIDATOR": None, "SUPERSET_META_DB_LIMIT": 1},
)
mocker.patch("superset.extensions.metadb.security_manager")

Expand Down
Loading