diff --git a/superset/__init__.py b/superset/__init__.py index 6df897f3ecdb1..5c8ff3ca2dc57 100644 --- a/superset/__init__.py +++ b/superset/__init__.py @@ -19,7 +19,6 @@ from werkzeug.local import LocalProxy from superset.app import create_app -from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import ( appbuilder, cache_manager, diff --git a/superset/cachekeys/api.py b/superset/cachekeys/api.py index 6eb0d54d9eef0..e686c0a6df9e7 100644 --- a/superset/cachekeys/api.py +++ b/superset/cachekeys/api.py @@ -25,7 +25,7 @@ from sqlalchemy.exc import SQLAlchemyError from superset.cachekeys.schemas import CacheInvalidationRequestSchema -from superset.connectors.connector_registry import ConnectorRegistry +from superset.connectors.sqla.models import SqlaTable from superset.extensions import cache_manager, db, event_logger from superset.models.cache import CacheKey from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics @@ -83,13 +83,13 @@ def invalidate(self) -> Response: return self.response_400(message=str(error)) datasource_uids = set(datasources.get("datasource_uids", [])) for ds in datasources.get("datasources", []): - ds_obj = ConnectorRegistry.get_datasource_by_name( + ds_obj = SqlaTable.get_datasource_by_name( session=db.session, - datasource_type=ds.get("datasource_type"), datasource_name=ds.get("datasource_name"), schema=ds.get("schema"), database_name=ds.get("database_name"), ) + if ds_obj: datasource_uids.add(ds_obj.uid) diff --git a/superset/commands/utils.py b/superset/commands/utils.py index f7564b3de7689..0be5c52e31fd7 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -25,9 +25,10 @@ OwnersNotFoundValidationError, RolesNotFoundValidationError, ) -from superset.connectors.connector_registry import ConnectorRegistry -from superset.datasets.commands.exceptions import DatasetNotFoundError +from superset.dao.exceptions import DatasourceNotFound +from superset.datasource.dao import DatasourceDAO from superset.extensions import db, security_manager +from superset.utils.core import DatasourceType if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -79,8 +80,8 @@ def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]: def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource: try: - return ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + return DatasourceDAO.get_datasource( + db.session, DatasourceType(datasource_type), datasource_id ) - except DatasetNotFoundError as ex: + except DatasourceNotFound as ex: raise DatasourceNotFoundValidationError() from ex diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 2056109bbff70..1e1d16985ad43 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -22,8 +22,8 @@ from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext from superset.common.query_object_factory import QueryObjectFactory -from superset.connectors.connector_registry import ConnectorRegistry -from superset.utils.core import DatasourceDict +from superset.datasource.dao import DatasourceDAO +from superset.utils.core import DatasourceDict, DatasourceType if TYPE_CHECKING: from superset.connectors.base.models import BaseDatasource @@ -32,7 +32,7 @@ def create_query_object_factory() -> QueryObjectFactory: - return QueryObjectFactory(config, ConnectorRegistry(), db.session) + return QueryObjectFactory(config, DatasourceDAO(), db.session) class QueryContextFactory: # pylint: disable=too-few-public-methods @@ -82,6 +82,6 @@ def create( # pylint: disable=no-self-use def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: - return ConnectorRegistry.get_datasource( - str(datasource["type"]), int(datasource["id"]), db.session + return DatasourceDAO.get_datasource( + db.session, DatasourceType(datasource["type"]), int(datasource["id"]) ) diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index 64ae99deebabc..e9f5122975b52 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -21,29 +21,29 @@ from superset.common.chart_data import ChartDataResultType from superset.common.query_object import QueryObject -from superset.utils.core import apply_max_row_limit, DatasourceDict +from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType from superset.utils.date_parser import get_since_until if TYPE_CHECKING: from sqlalchemy.orm import sessionmaker - from superset import ConnectorRegistry from superset.connectors.base.models import BaseDatasource + from superset.datasource.dao import DatasourceDAO class QueryObjectFactory: # pylint: disable=too-few-public-methods _config: Dict[str, Any] - _connector_registry: ConnectorRegistry + _datasource_dao: DatasourceDAO _session_maker: sessionmaker def __init__( self, app_configurations: Dict[str, Any], - connector_registry: ConnectorRegistry, + _datasource_dao: DatasourceDAO, session_maker: sessionmaker, ): self._config = app_configurations - self._connector_registry = connector_registry + self._datasource_dao = _datasource_dao self._session_maker = session_maker def create( # pylint: disable=too-many-arguments @@ -75,8 +75,10 @@ def create( # pylint: disable=too-many-arguments ) def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: - return self._connector_registry.get_datasource( - str(datasource["type"]), int(datasource["id"]), self._session_maker() + return self._datasource_dao.get_datasource( + datasource_type=DatasourceType(datasource["type"]), + datasource_id=int(datasource["id"]), + session=self._session_maker(), ) def _process_extras( # pylint: disable=no-self-use diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py deleted file mode 100644 index 06816fa53049f..0000000000000 --- a/superset/connectors/connector_registry.py +++ /dev/null @@ -1,164 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING - -from flask_babel import _ -from sqlalchemy import or_ -from sqlalchemy.orm import Session, subqueryload -from sqlalchemy.orm.exc import NoResultFound - -from superset.datasets.commands.exceptions import DatasetNotFoundError - -if TYPE_CHECKING: - from collections import OrderedDict - - from superset.connectors.base.models import BaseDatasource - from superset.models.core import Database - - -class ConnectorRegistry: - """Central Registry for all available datasource engines""" - - sources: Dict[str, Type["BaseDatasource"]] = {} - - @classmethod - def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None: - for module_name, class_names in datasource_config.items(): - class_names = [str(s) for s in class_names] - module_obj = __import__(module_name, fromlist=class_names) - for class_name in class_names: - source_class = getattr(module_obj, class_name) - cls.sources[source_class.type] = source_class - - @classmethod - def get_datasource( - cls, datasource_type: str, datasource_id: int, session: Session - ) -> "BaseDatasource": - """Safely get a datasource instance, raises `DatasetNotFoundError` if - `datasource_type` is not registered or `datasource_id` does not - exist.""" - if datasource_type not in cls.sources: - raise DatasetNotFoundError() - - datasource = ( - session.query(cls.sources[datasource_type]) - .filter_by(id=datasource_id) - .one_or_none() - ) - - if not datasource: - raise DatasetNotFoundError() - - return datasource - - @classmethod - def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]: - datasources: List["BaseDatasource"] = [] - for source_class in ConnectorRegistry.sources.values(): - qry = session.query(source_class) - qry = source_class.default_query(qry) - datasources.extend(qry.all()) - return datasources - - @classmethod - def get_datasource_by_id( - cls, session: Session, datasource_id: int - ) -> "BaseDatasource": - """ - Find a datasource instance based on the unique id. - - :param session: Session to use - :param datasource_id: unique id of datasource - :return: Datasource corresponding to the id - :raises NoResultFound: if no datasource is found corresponding to the id - """ - for datasource_class in ConnectorRegistry.sources.values(): - try: - return ( - session.query(datasource_class) - .filter(datasource_class.id == datasource_id) - .one() - ) - except NoResultFound: - # proceed to next datasource type - pass - raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id)) - - @classmethod - def get_datasource_by_name( # pylint: disable=too-many-arguments - cls, - session: Session, - datasource_type: str, - datasource_name: str, - schema: str, - database_name: str, - ) -> Optional["BaseDatasource"]: - datasource_class = ConnectorRegistry.sources[datasource_type] - return datasource_class.get_datasource_by_name( - session, datasource_name, schema, database_name - ) - - @classmethod - def query_datasources_by_permissions( # pylint: disable=invalid-name - cls, - session: Session, - database: "Database", - permissions: Set[str], - schema_perms: Set[str], - ) -> List["BaseDatasource"]: - # TODO(bogdan): add unit test - datasource_class = ConnectorRegistry.sources[database.type] - return ( - session.query(datasource_class) - .filter_by(database_id=database.id) - .filter( - or_( - datasource_class.perm.in_(permissions), - datasource_class.schema_perm.in_(schema_perms), - ) - ) - .all() - ) - - @classmethod - def get_eager_datasource( - cls, session: Session, datasource_type: str, datasource_id: int - ) -> "BaseDatasource": - """Returns datasource with columns and metrics.""" - datasource_class = ConnectorRegistry.sources[datasource_type] - return ( - session.query(datasource_class) - .options( - subqueryload(datasource_class.columns), - subqueryload(datasource_class.metrics), - ) - .filter_by(id=datasource_id) - .one() - ) - - @classmethod - def query_datasources_by_name( - cls, - session: Session, - database: "Database", - datasource_name: str, - schema: Optional[str] = None, - ) -> List["BaseDatasource"]: - datasource_class = ConnectorRegistry.sources[database.type] - return datasource_class.query_datasources_by_name( - session, database, datasource_name, schema=schema - ) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ff90cb2a56fef..57730cc711898 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -31,6 +31,7 @@ List, NamedTuple, Optional, + Set, Tuple, Type, Union, @@ -1990,6 +1991,48 @@ def query_datasources_by_name( query = query.filter_by(schema=schema) return query.all() + @classmethod + def query_datasources_by_permissions( # pylint: disable=invalid-name + cls, + session: Session, + database: Database, + permissions: Set[str], + schema_perms: Set[str], + ) -> List["SqlaTable"]: + # TODO(hughhhh): add unit test + return ( + session.query(cls) + .filter_by(database_id=database.id) + .filter( + or_( + SqlaTable.perm.in_(permissions), + SqlaTable.schema_perm.in_(schema_perms), + ) + ) + .all() + ) + + @classmethod + def get_eager_sqlatable_datasource( + cls, session: Session, datasource_id: int + ) -> "SqlaTable": + """Returns SqlaTable with columns and metrics.""" + return ( + session.query(cls) + .options( + sa.orm.subqueryload(cls.columns), + sa.orm.subqueryload(cls.metrics), + ) + .filter_by(id=datasource_id) + .one() + ) + + @classmethod + def get_all_datasources(cls, session: Session) -> List["SqlaTable"]: + qry = session.query(cls) + qry = cls.default_query(qry) + return qry.all() + @staticmethod def default_query(qry: Query) -> Query: return qry.filter_by(is_sqllab_view=False) diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py deleted file mode 100644 index caa45564aa250..0000000000000 --- a/superset/dao/datasource/dao.py +++ /dev/null @@ -1,147 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from enum import Enum -from typing import Any, Dict, List, Optional, Set, Type, Union - -from flask_babel import _ -from sqlalchemy import or_ -from sqlalchemy.orm import Session, subqueryload -from sqlalchemy.orm.exc import NoResultFound - -from superset.connectors.sqla.models import SqlaTable -from superset.dao.base import BaseDAO -from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError -from superset.datasets.commands.exceptions import DatasetNotFoundError -from superset.datasets.models import Dataset -from superset.models.core import Database -from superset.models.sql_lab import Query, SavedQuery -from superset.tables.models import Table -from superset.utils.core import DatasourceType - -Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] - - -class DatasourceDAO(BaseDAO): - - sources: Dict[DatasourceType, Type[Datasource]] = { - DatasourceType.TABLE: SqlaTable, - DatasourceType.QUERY: Query, - DatasourceType.SAVEDQUERY: SavedQuery, - DatasourceType.DATASET: Dataset, - DatasourceType.SLTABLE: Table, - } - - @classmethod - def get_datasource( - cls, session: Session, datasource_type: DatasourceType, datasource_id: int - ) -> Datasource: - if datasource_type not in cls.sources: - raise DatasourceTypeNotSupportedError() - - datasource = ( - session.query(cls.sources[datasource_type]) - .filter_by(id=datasource_id) - .one_or_none() - ) - - if not datasource: - raise DatasourceNotFound() - - return datasource - - @classmethod - def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]: - source_class = DatasourceDAO.sources[DatasourceType.TABLE] - qry = session.query(source_class) - qry = source_class.default_query(qry) - return qry.all() - - @classmethod - def get_datasource_by_name( # pylint: disable=too-many-arguments - cls, - session: Session, - datasource_type: DatasourceType, - datasource_name: str, - database_name: str, - schema: str, - ) -> Optional[Datasource]: - datasource_class = DatasourceDAO.sources[datasource_type] - if isinstance(datasource_class, SqlaTable): - return datasource_class.get_datasource_by_name( - session, datasource_name, schema, database_name - ) - return None - - @classmethod - def query_datasources_by_permissions( # pylint: disable=invalid-name - cls, - session: Session, - database: Database, - permissions: Set[str], - schema_perms: Set[str], - ) -> List[Datasource]: - # TODO(hughhhh): add unit test - datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] - if not isinstance(datasource_class, SqlaTable): - return [] - - return ( - session.query(datasource_class) - .filter_by(database_id=database.id) - .filter( - or_( - datasource_class.perm.in_(permissions), - datasource_class.schema_perm.in_(schema_perms), - ) - ) - .all() - ) - - @classmethod - def get_eager_datasource( - cls, session: Session, datasource_type: str, datasource_id: int - ) -> Optional[Datasource]: - """Returns datasource with columns and metrics.""" - datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]] - if not isinstance(datasource_class, SqlaTable): - return None - return ( - session.query(datasource_class) - .options( - subqueryload(datasource_class.columns), - subqueryload(datasource_class.metrics), - ) - .filter_by(id=datasource_id) - .one() - ) - - @classmethod - def query_datasources_by_name( - cls, - session: Session, - database: Database, - datasource_name: str, - schema: Optional[str] = None, - ) -> List[Datasource]: - datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] - if not isinstance(datasource_class, SqlaTable): - return [] - - return datasource_class.query_datasources_by_name( - session, database, datasource_name, schema=schema - ) diff --git a/superset/dao/exceptions.py b/superset/dao/exceptions.py index 9b5624bd5d31d..93cb25d3fc70e 100644 --- a/superset/dao/exceptions.py +++ b/superset/dao/exceptions.py @@ -60,6 +60,7 @@ class DatasourceTypeNotSupportedError(DAOException): DAO datasource query source type is not supported """ + status = 422 message = "DAO datasource query source type is not supported" diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index 207920b1d2c2a..e49c931896838 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -24,7 +24,7 @@ from flask_babel import lazy_gettext as _ from sqlalchemy.orm import make_transient, Session -from superset import ConnectorRegistry, db +from superset import db from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.datasets.commands.importers.v0 import import_dataset @@ -63,12 +63,11 @@ def import_chart( slc_to_import = slc_to_import.copy() slc_to_import.reset_ownership() params = slc_to_import.params_dict - datasource = ConnectorRegistry.get_datasource_by_name( - session, - slc_to_import.datasource_type, - params["datasource_name"], - params["schema"], - params["database_name"], + datasource = SqlaTable.get_datasource_by_name( + session=session, + datasource_name=params["datasource_name"], + database_name=params["database_name"], + schema=params["schema"], ) slc_to_import.datasource_id = datasource.id # type: ignore if slc_to_override: diff --git a/superset/datasource/__init__.py b/superset/datasource/__init__.py new file mode 100644 index 0000000000000..e0533d99236c2 --- /dev/null +++ b/superset/datasource/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py new file mode 100644 index 0000000000000..c475919abf006 --- /dev/null +++ b/superset/datasource/dao.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Dict, Type, Union + +from sqlalchemy.orm import Session + +from superset.connectors.sqla.models import SqlaTable +from superset.dao.base import BaseDAO +from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError +from superset.datasets.models import Dataset +from superset.models.sql_lab import Query, SavedQuery +from superset.tables.models import Table +from superset.utils.core import DatasourceType + +Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] + + +class DatasourceDAO(BaseDAO): + + sources: Dict[Union[DatasourceType, str], Type[Datasource]] = { + DatasourceType.TABLE: SqlaTable, + DatasourceType.QUERY: Query, + DatasourceType.SAVEDQUERY: SavedQuery, + DatasourceType.DATASET: Dataset, + DatasourceType.SLTABLE: Table, + } + + @classmethod + def get_datasource( + cls, + session: Session, + datasource_type: Union[DatasourceType, str], + datasource_id: int, + ) -> Datasource: + if datasource_type not in cls.sources: + raise DatasourceTypeNotSupportedError() + + datasource = ( + session.query(cls.sources[datasource_type]) + .filter_by(id=datasource_id) + .one_or_none() + ) + + if not datasource: + raise DatasourceNotFound() + + return datasource diff --git a/superset/examples/helpers.py b/superset/examples/helpers.py index 9d17e73773299..8c2ad29f49102 100644 --- a/superset/examples/helpers.py +++ b/superset/examples/helpers.py @@ -23,7 +23,7 @@ from urllib import request from superset import app, db -from superset.connectors.connector_registry import ConnectorRegistry +from superset.connectors.sqla.models import SqlaTable from superset.models.slice import Slice BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/" @@ -32,7 +32,7 @@ def get_table_connector_registry() -> Any: - return ConnectorRegistry.sources["table"] + return SqlaTable def get_examples_folder() -> str: diff --git a/superset/explore/form_data/commands/create.py b/superset/explore/form_data/commands/create.py index 7946980c82684..5c301a96f1a12 100644 --- a/superset/explore/form_data/commands/create.py +++ b/superset/explore/form_data/commands/create.py @@ -27,6 +27,7 @@ from superset.key_value.utils import get_owner, random_key from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError from superset.temporary_cache.utils import cache_key +from superset.utils.core import DatasourceType from superset.utils.schema import validate_json logger = logging.getLogger(__name__) @@ -56,7 +57,7 @@ def run(self) -> str: state: TemporaryExploreState = { "owner": get_owner(actor), "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_type": DatasourceType(datasource_type), "chart_id": chart_id, "form_data": form_data, } diff --git a/superset/explore/form_data/commands/state.py b/superset/explore/form_data/commands/state.py index 470f2e22f5989..35e3893478ea0 100644 --- a/superset/explore/form_data/commands/state.py +++ b/superset/explore/form_data/commands/state.py @@ -18,10 +18,12 @@ from typing_extensions import TypedDict +from superset.utils.core import DatasourceType + class TemporaryExploreState(TypedDict): owner: Optional[int] datasource_id: int - datasource_type: str + datasource_type: DatasourceType chart_id: Optional[int] form_data: str diff --git a/superset/explore/form_data/commands/update.py b/superset/explore/form_data/commands/update.py index fdc75093bef85..f48d8e85ef5ba 100644 --- a/superset/explore/form_data/commands/update.py +++ b/superset/explore/form_data/commands/update.py @@ -32,6 +32,7 @@ TemporaryCacheUpdateFailedError, ) from superset.temporary_cache.utils import cache_key +from superset.utils.core import DatasourceType from superset.utils.schema import validate_json logger = logging.getLogger(__name__) @@ -75,7 +76,7 @@ def run(self) -> Optional[str]: new_state: TemporaryExploreState = { "owner": owner, "datasource_id": datasource_id, - "datasource_type": datasource_type, + "datasource_type": DatasourceType(datasource_type), "chart_id": chart_id, "form_data": form_data, } diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 698c3881390ef..426dc1b524d19 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -28,7 +28,6 @@ from flask_compress import Compress from werkzeug.middleware.proxy_fix import ProxyFix -from superset.connectors.connector_registry import ConnectorRegistry from superset.constants import CHANGE_ME_SECRET_KEY from superset.extensions import ( _event_logger, @@ -473,7 +472,11 @@ def configure_data_sources(self) -> None: # Registering sources module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"] module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"]) - ConnectorRegistry.register_sources(module_datasource_map) + + # todo(hughhhh): fully remove the datasource config register + for module_name, class_names in module_datasource_map.items(): + class_names = [str(s) for s in class_names] + __import__(module_name, fromlist=class_names) def configure_cache(self) -> None: cache_manager.init_app(self.superset_app) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index f2d53e1ff5e6a..12f7056161328 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -46,10 +46,11 @@ from sqlalchemy.sql import join, select from sqlalchemy.sql.elements import BinaryExpression -from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager +from superset import app, db, is_feature_enabled, security_manager from superset.common.request_contexed_based import is_user_admin from superset.connectors.base.models import BaseDatasource -from superset.connectors.sqla.models import SqlMetric, TableColumn +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.datasource.dao import DatasourceDAO from superset.extensions import cache_manager from superset.models.filter_set import FilterSet from superset.models.helpers import AuditMixinNullable, ImportExportMixin @@ -407,16 +408,18 @@ def export_dashboards( # pylint: disable=too-many-locals id_ = target.get("datasetId") if id_ is None: continue - datasource = ConnectorRegistry.get_datasource_by_id(session, id_) + datasource = DatasourceDAO.get_datasource( + session, utils.DatasourceType.TABLE, id_ + ) datasource_ids.add((datasource.id, datasource.type)) copied_dashboard.alter_params(remote_id=dashboard_id) copied_dashboards.append(copied_dashboard) eager_datasources = [] - for datasource_id, datasource_type in datasource_ids: - eager_datasource = ConnectorRegistry.get_eager_datasource( - db.session, datasource_type, datasource_id + for datasource_id, _ in datasource_ids: + eager_datasource = SqlaTable.get_eager_sqlatable_datasource( + db.session, datasource_id ) copied_datasource = eager_datasource.copy() copied_datasource.alter_params( diff --git a/superset/models/datasource_access_request.py b/superset/models/datasource_access_request.py index fa3b9d67113d3..60bfe08238284 100644 --- a/superset/models/datasource_access_request.py +++ b/superset/models/datasource_access_request.py @@ -21,7 +21,6 @@ from sqlalchemy import Column, Integer, String from superset import app, db, security_manager -from superset.connectors.connector_registry import ConnectorRegistry from superset.models.helpers import AuditMixinNullable from superset.utils.memoized import memoized @@ -44,7 +43,10 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): @property def cls_model(self) -> Type["BaseDatasource"]: - return ConnectorRegistry.sources[self.datasource_type] + # pylint: disable=import-outside-toplevel + from superset.datasource.dao import DatasourceDAO + + return DatasourceDAO.sources[self.datasource_type] @property def username(self) -> Markup: diff --git a/superset/models/slice.py b/superset/models/slice.py index 862edb9ec8ce8..841539bc66573 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -39,7 +39,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm.mapper import Mapper -from superset import ConnectorRegistry, db, is_feature_enabled, security_manager +from superset import db, is_feature_enabled, security_manager from superset.legacy import update_time_range from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.models.tags import ChartUpdater @@ -126,7 +126,10 @@ def __repr__(self) -> str: @property def cls_model(self) -> Type["BaseDatasource"]: - return ConnectorRegistry.sources[self.datasource_type] + # pylint: disable=import-outside-toplevel + from superset.datasource.dao import DatasourceDAO + + return DatasourceDAO.sources[self.datasource_type] @property def datasource(self) -> Optional["BaseDatasource"]: diff --git a/superset/security/manager.py b/superset/security/manager.py index 6157959aa3739..890a09415ecb3 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -61,7 +61,6 @@ from sqlalchemy.orm.query import Query as SqlaQuery from superset import sql_parse -from superset.connectors.connector_registry import ConnectorRegistry from superset.constants import RouteMethod from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -471,23 +470,25 @@ def get_user_datasources(self) -> List["BaseDatasource"]: user_perms = self.user_view_menu_names("datasource_access") schema_perms = self.user_view_menu_names("schema_access") user_datasources = set() - for datasource_class in ConnectorRegistry.sources.values(): - user_datasources.update( - self.get_session.query(datasource_class) - .filter( - or_( - datasource_class.perm.in_(user_perms), - datasource_class.schema_perm.in_(schema_perms), - ) + + # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import SqlaTable + + user_datasources.update( + self.get_session.query(SqlaTable) + .filter( + or_( + SqlaTable.perm.in_(user_perms), + SqlaTable.schema_perm.in_(schema_perms), ) - .all() ) + .all() + ) # group all datasources by database - all_datasources = ConnectorRegistry.get_all_datasources(self.get_session) - datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict( - set - ) + session = self.get_session + all_datasources = SqlaTable.get_all_datasources(session) + datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set) for datasource in all_datasources: datasources_by_database[datasource.database].add(datasource) @@ -599,6 +600,8 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name :param schema: The fallback SQL schema if not present in the table name :returns: The list of accessible SQL tables w/ schema """ + # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import SqlaTable if self.can_access_database(database): return datasource_names @@ -610,7 +613,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name user_perms = self.user_view_menu_names("datasource_access") schema_perms = self.user_view_menu_names("schema_access") - user_datasources = ConnectorRegistry.query_datasources_by_permissions( + user_datasources = SqlaTable.query_datasources_by_permissions( self.get_session, database, user_perms, schema_perms ) if schema: @@ -660,6 +663,7 @@ def create_missing_perms(self) -> None: """ # pylint: disable=import-outside-toplevel + from superset.connectors.sqla.models import SqlaTable from superset.models import core as models logger.info("Fetching a set of all perms to lookup which ones are missing") @@ -668,13 +672,13 @@ def create_missing_perms(self) -> None: if pv.permission and pv.view_menu: all_pvs.add((pv.permission.name, pv.view_menu.name)) - def merge_pv(view_menu: str, perm: str) -> None: + def merge_pv(view_menu: str, perm: Optional[str]) -> None: """Create permission view menu only if it doesn't exist""" if view_menu and perm and (view_menu, perm) not in all_pvs: self.add_permission_view_menu(view_menu, perm) logger.info("Creating missing datasource permissions.") - datasources = ConnectorRegistry.get_all_datasources(self.get_session) + datasources = SqlaTable.get_all_datasources(self.get_session) for datasource in datasources: merge_pv("datasource_access", datasource.get_perm()) merge_pv("schema_access", datasource.get_schema_perm()) diff --git a/superset/views/core.py b/superset/views/core.py index f65385fc305a5..04d3835f60322 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -61,7 +61,6 @@ from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.db_query_status import QueryStatus from superset.connectors.base.models import BaseDatasource -from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.sqla.models import ( AnnotationDatasource, SqlaTable, @@ -77,6 +76,7 @@ from superset.databases.filters import DatabaseFilter from superset.databases.utils import make_url_safe from superset.datasets.commands.exceptions import DatasetNotFoundError +from superset.datasource.dao import DatasourceDAO from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( CacheLoadError, @@ -129,7 +129,11 @@ from superset.utils import core as utils, csv from superset.utils.async_query_manager import AsyncQueryTokenException from superset.utils.cache import etag_cache -from superset.utils.core import apply_max_row_limit, ReservedUrlParameters +from superset.utils.core import ( + apply_max_row_limit, + DatasourceType, + ReservedUrlParameters, +) from superset.utils.dates import now_as_float from superset.utils.decorators import check_dashboard_access from superset.views.base import ( @@ -250,7 +254,7 @@ def override_role_permissions(self) -> FlaskResponse: ) db_ds_names.add(fullname) - existing_datasources = ConnectorRegistry.get_all_datasources(db.session) + existing_datasources = SqlaTable.get_all_datasources(db.session) datasources = [d for d in existing_datasources if d.full_name in db_ds_names] role = security_manager.find_role(role_name) # remove all permissions @@ -282,7 +286,7 @@ def request_access(self) -> FlaskResponse: datasource_id = request.args.get("datasource_id") datasource_type = request.args.get("datasource_type") if datasource_id and datasource_type: - ds_class = ConnectorRegistry.sources.get(datasource_type) + ds_class = DatasourceDAO.sources.get(datasource_type) datasource = ( db.session.query(ds_class).filter_by(id=int(datasource_id)).one() ) @@ -319,10 +323,8 @@ def request_access(self) -> FlaskResponse: def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use def clean_fulfilled_requests(session: Session) -> None: for dar in session.query(DAR).all(): - datasource = ConnectorRegistry.get_datasource( - dar.datasource_type, - dar.datasource_id, - session, + datasource = DatasourceDAO.get_datasource( + session, DatasourceType(dar.datasource_type), dar.datasource_id ) if not datasource or security_manager.can_access_datasource(datasource): # Dataset does not exist anymore @@ -336,8 +338,8 @@ def clean_fulfilled_requests(session: Session) -> None: role_to_extend = request.args.get("role_to_extend") session = db.session - datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, session + datasource = DatasourceDAO.get_datasource( + session, DatasourceType(datasource_type), int(datasource_id) ) if not datasource: @@ -639,7 +641,6 @@ def explore_json( datasource_id, datasource_type = get_datasource_info( datasource_id, datasource_type, form_data ) - force = request.args.get("force") == "true" # TODO: support CSV, SQL query and other non-JSON types @@ -809,8 +810,10 @@ def explore( datasource: Optional[BaseDatasource] = None if datasource_id is not None: try: - datasource = ConnectorRegistry.get_datasource( - cast(str, datasource_type), datasource_id, db.session + datasource = DatasourceDAO.get_datasource( + db.session, + DatasourceType(cast(str, datasource_type)), + datasource_id, ) except DatasetNotFoundError: pass @@ -948,10 +951,8 @@ def filter( # pylint: disable=no-self-use :raises SupersetSecurityException: If the user cannot access the resource """ # TODO: Cache endpoint by user, datasource and column - datasource = ConnectorRegistry.get_datasource( - datasource_type, - datasource_id, - db.session, + datasource = DatasourceDAO.get_datasource( + db.session, DatasourceType(datasource_type), datasource_id ) if not datasource: return json_error_response(DATASOURCE_MISSING_ERR) @@ -1920,8 +1921,8 @@ def dashboard( if config["ENABLE_ACCESS_REQUEST"]: for datasource in dashboard.datasources: - datasource = ConnectorRegistry.get_datasource( - datasource_type=datasource.type, + datasource = DatasourceDAO.get_datasource( + datasource_type=DatasourceType(datasource.type), datasource_id=datasource.id, session=db.session(), ) @@ -2537,10 +2538,8 @@ def fetch_datasource_metadata(self) -> FlaskResponse: # pylint: disable=no-self """ datasource_id, datasource_type = request.args["datasourceKey"].split("__") - datasource = ConnectorRegistry.get_datasource( - datasource_type, - datasource_id, - db.session, + datasource = DatasourceDAO.get_datasource( + db.session, DatasourceType(datasource_type), int(datasource_id) ) # Check if datasource exists if not datasource: diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index 560c12d6f19b5..bf67eddd01199 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -29,16 +29,18 @@ from superset import db, event_logger from superset.commands.utils import populate_owners -from superset.connectors.connector_registry import ConnectorRegistry +from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.utils import get_physical_table_metadata from superset.datasets.commands.exceptions import ( DatasetForbiddenError, DatasetNotFoundError, ) +from superset.datasource.dao import DatasourceDAO from superset.exceptions import SupersetException, SupersetSecurityException from superset.extensions import security_manager from superset.models.core import Database from superset.superset_typing import FlaskResponse +from superset.utils.core import DatasourceType from superset.views.base import ( api, BaseSupersetView, @@ -74,8 +76,8 @@ def save(self) -> FlaskResponse: datasource_id = datasource_dict.get("id") datasource_type = datasource_dict.get("type") database_id = datasource_dict["database"].get("id") - orm_datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + orm_datasource = DatasourceDAO.get_datasource( + db.session, DatasourceType(datasource_type), datasource_id ) orm_datasource.database_id = database_id @@ -117,8 +119,8 @@ def save(self) -> FlaskResponse: @api @handle_api_exception def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse: - datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + datasource = DatasourceDAO.get_datasource( + db.session, DatasourceType(datasource_type), datasource_id ) return self.json_response(sanitize_datasource_data(datasource.data)) @@ -130,8 +132,10 @@ def external_metadata( self, datasource_type: str, datasource_id: int ) -> FlaskResponse: """Gets column info from the source system""" - datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + datasource = DatasourceDAO.get_datasource( + db.session, + DatasourceType(datasource_type), + datasource_id, ) try: external_metadata = datasource.external_metadata() @@ -153,9 +157,8 @@ def external_metadata_by_name(self, **kwargs: Any) -> FlaskResponse: except ValidationError as err: return json_error_response(str(err), status=400) - datasource = ConnectorRegistry.get_datasource_by_name( + datasource = SqlaTable.get_datasource_by_name( session=db.session, - datasource_type=params["datasource_type"], database_name=params["database_name"], schema=params["schema_name"], datasource_name=params["table_name"], diff --git a/superset/views/utils.py b/superset/views/utils.py index e0f97cba1839b..719642ef13a96 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -32,7 +32,7 @@ import superset.models.core as models from superset import app, dataframe, db, result_set, viz from superset.common.db_query_status import QueryStatus -from superset.connectors.connector_registry import ConnectorRegistry +from superset.datasource.dao import DatasourceDAO from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( CacheLoadError, @@ -47,6 +47,7 @@ from superset.models.slice import Slice from superset.models.sql_lab import Query from superset.superset_typing import FormData +from superset.utils.core import DatasourceType from superset.utils.decorators import stats_timing from superset.viz import BaseViz @@ -127,8 +128,10 @@ def get_viz( force_cached: bool = False, ) -> BaseViz: viz_type = form_data.get("viz_type", "table") - datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + datasource = DatasourceDAO.get_datasource( + db.session, + DatasourceType(datasource_type), + datasource_id, ) viz_obj = viz.viz_types[viz_type]( datasource, form_data=form_data, force=force, force_cached=force_cached diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index c2319ff5b52c1..59b5b19b298c4 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -39,7 +39,6 @@ ) from tests.integration_tests.test_app import app # isort:skip from superset import db, security_manager -from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.sqla.models import SqlaTable from superset.models import core as models from superset.models.datasource_access_request import DatasourceAccessRequest @@ -90,12 +89,12 @@ def create_access_request(session, ds_type, ds_name, role_name, username): - ds_class = ConnectorRegistry.sources[ds_type] # TODO: generalize datasource names if ds_type == "table": - ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first() + ds = session.query(SqlaTable).filter(SqlaTable.table_name == ds_name).first() else: - ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first() + # This function will only work for ds_type == "table" + raise NotImplementedError() ds_perm_view = security_manager.find_permission_view_menu( "datasource_access", ds.perm ) @@ -449,49 +448,6 @@ def test_approve(self, mock_send_mime): TEST_ROLE = security_manager.find_role(TEST_ROLE_NAME) self.assertIn(perm_view, TEST_ROLE.permissions) - # Case 3. Grant new role to the user to access the druid datasource. - - security_manager.add_role("druid_role") - access_request3 = create_access_request( - session, "druid", "druid_ds_1", "druid_role", "gamma" - ) - self.get_resp( - GRANT_ROLE_REQUEST.format( - "druid", access_request3.datasource_id, "gamma", "druid_role" - ) - ) - - # user was granted table_role - user_roles = [r.name for r in security_manager.find_user("gamma").roles] - self.assertIn("druid_role", user_roles) - - # Case 4. Extend the role to have access to the druid datasource - - access_request4 = create_access_request( - session, "druid", "druid_ds_2", "druid_role", "gamma" - ) - druid_ds_2_perm = access_request4.datasource.perm - - self.client.get( - EXTEND_ROLE_REQUEST.format( - "druid", access_request4.datasource_id, "gamma", "druid_role" - ) - ) - # druid_role was extended to grant access to the druid_access_ds_2 - druid_role = security_manager.find_role("druid_role") - perm_view = security_manager.find_permission_view_menu( - "datasource_access", druid_ds_2_perm - ) - self.assertIn(perm_view, druid_role.permissions) - - # cleanup - gamma_user = security_manager.find_user(username="gamma") - gamma_user.roles.remove(security_manager.find_role("druid_role")) - gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME)) - session.delete(security_manager.find_role("druid_role")) - session.delete(security_manager.find_role(TEST_ROLE_NAME)) - session.commit() - def test_request_access(self): if app.config["ENABLE_ACCESS_REQUEST"]: session = db.session diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index 41a34fa36edf5..115d3269f2e50 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -21,7 +21,7 @@ from pandas import DataFrame -from superset import ConnectorRegistry, db +from superset import db from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -35,9 +35,8 @@ def get_table( schema: Optional[str] = None, ): schema = schema or get_example_default_schema() - table_source = ConnectorRegistry.sources["table"] return ( - db.session.query(table_source) + db.session.query(SqlaTable) .filter_by(database_id=database.id, schema=schema, table_name=table_name) .one_or_none() ) @@ -54,8 +53,7 @@ def create_table_metadata( table = get_table(table_name, database, schema) if not table: - table_source = ConnectorRegistry.sources["table"] - table = table_source(schema=schema, table_name=table_name) + table = SqlaTable(schema=schema, table_name=table_name) if fetch_values_predicate: table.fetch_values_predicate = fetch_values_predicate table.database = database diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 6d46afa0a9ddd..6c8ae672c5845 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -22,12 +22,13 @@ import prison import pytest -from superset import app, ConnectorRegistry, db +from superset import app, db from superset.connectors.sqla.models import SqlaTable +from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.exceptions import SupersetGenericDBErrorException from superset.models.core import Database -from superset.utils.core import get_example_default_schema +from superset.utils.core import DatasourceType, get_example_default_schema from superset.utils.database import get_example_database from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( @@ -256,9 +257,10 @@ def test_external_metadata_error_return_400(self, mock_get_datasource): pytest.raises( SupersetGenericDBErrorException, - lambda: ConnectorRegistry.get_datasource( - "table", tbl.id, db.session - ).external_metadata(), + lambda: db.session.query(SqlaTable) + .filter_by(id=tbl.id) + .one_or_none() + .external_metadata(), ) resp = self.client.get(url) @@ -385,21 +387,30 @@ def my_check(datasource): app.config["DATASET_HEALTH_CHECK"] = my_check self.login(username="admin") tbl = self.get_table(name="birth_names") - datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session) + datasource = db.session.query(SqlaTable).filter_by(id=tbl.id).one_or_none() assert datasource.health_check_message == "Warning message!" app.config["DATASET_HEALTH_CHECK"] = None def test_get_datasource_failed(self): + from superset.datasource.dao import DatasourceDAO + pytest.raises( - DatasetNotFoundError, - lambda: ConnectorRegistry.get_datasource("table", 9999999, db.session), + DatasourceNotFound, + lambda: DatasourceDAO.get_datasource(db.session, "table", 9999999), ) self.login(username="admin") - resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False) - self.assertEqual(resp.get("error"), "Dataset does not exist") + resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False) + self.assertEqual(resp.get("error"), "Datasource does not exist") - resp = self.get_json_resp( - "/datasource/get/invalid-datasource-type/500000/", raise_on_error=False + def test_get_datasource_invalid_datasource_failed(self): + from superset.datasource.dao import DatasourceDAO + + pytest.raises( + DatasourceTypeNotSupportedError, + lambda: DatasourceDAO.get_datasource(db.session, "druid", 9999999), ) - self.assertEqual(resp.get("error"), "Dataset does not exist") + + self.login(username="admin") + resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False) + self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType") diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index 8b375df56ae38..dae713ff7041b 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -26,6 +26,7 @@ from superset.explore.form_data.commands.state import TemporaryExploreState from superset.extensions import cache_manager from superset.models.slice import Slice +from superset.utils.core import DatasourceType from tests.integration_tests.base_tests import login from tests.integration_tests.fixtures.client import client from tests.integration_tests.fixtures.world_bank_dashboard import ( @@ -392,7 +393,7 @@ def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id entry: TemporaryExploreState = { "owner": another_owner, "datasource_id": datasource.id, - "datasource_type": datasource.type, + "datasource_type": DatasourceType(datasource.type), "chart_id": chart_id, "form_data": INITIAL_FORM_DATA, } diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index ef71803aa5db7..0434e22295267 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -18,7 +18,7 @@ import pytest -from superset import ConnectorRegistry, db +from superset import db from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -95,14 +95,11 @@ def _create_table( def _cleanup(dash_id: int, slices_ids: List[int]) -> None: schema = get_example_default_schema() - - table_id = ( + datasource = ( db.session.query(SqlaTable) .filter_by(table_name="birth_names", schema=schema) .one() - .id ) - datasource = ConnectorRegistry.get_datasource("table", table_id, db.session) columns = [column for column in datasource.columns] metrics = [metric for metric in datasource.metrics] diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index c0291db2a9864..0279fe8ff2f5c 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py @@ -82,7 +82,6 @@ def _create_energy_table(): table.metrics.append( SqlMetric(metric_name="sum__value", expression=f"SUM({col})") ) - db.session.merge(table) db.session.commit() table.fetch_metadata() diff --git a/tests/integration_tests/insert_chart_mixin.py b/tests/integration_tests/insert_chart_mixin.py index 8fcb33067e351..da05d0c49d043 100644 --- a/tests/integration_tests/insert_chart_mixin.py +++ b/tests/integration_tests/insert_chart_mixin.py @@ -16,7 +16,8 @@ # under the License. from typing import List, Optional -from superset import ConnectorRegistry, db, security_manager +from superset import db, security_manager +from superset.connectors.sqla.models import SqlaTable from superset.models.slice import Slice @@ -43,8 +44,8 @@ def insert_chart( for owner in owners: user = db.session.query(security_manager.user_model).get(owner) obj_owners.append(user) - datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + datasource = ( + db.session.query(SqlaTable).filter_by(id=datasource_id).one_or_none() ) slice = Slice( cache_timeout=cache_timeout, diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 816267678f9e0..6d5fec88f444d 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -26,10 +26,15 @@ from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext from superset.common.query_object import QueryObject -from superset.connectors.connector_registry import ConnectorRegistry from superset.connectors.sqla.models import SqlMetric +from superset.datasource.dao import DatasourceDAO from superset.extensions import cache_manager -from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus +from superset.utils.core import ( + AdhocMetricExpressionType, + backend, + DatasourceType, + QueryStatus, +) from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -132,10 +137,10 @@ def test_query_cache_key_changes_when_datasource_is_updated(self): cache_key_original = query_context.query_cache_key(query_object) # make temporary change and revert it to refresh the changed_on property - datasource = ConnectorRegistry.get_datasource( - datasource_type=payload["datasource"]["type"], - datasource_id=payload["datasource"]["id"], + datasource = DatasourceDAO.get_datasource( session=db.session, + datasource_type=DatasourceType(payload["datasource"]["type"]), + datasource_id=payload["datasource"]["id"], ) description_original = datasource.description datasource.description = "temporary description" @@ -156,10 +161,10 @@ def test_query_cache_key_changes_when_metric_is_updated(self): payload = get_query_context("birth_names") # make temporary change and revert it to refresh the changed_on property - datasource = ConnectorRegistry.get_datasource( - datasource_type=payload["datasource"]["type"], - datasource_id=payload["datasource"]["id"], + datasource = DatasourceDAO.get_datasource( session=db.session, + datasource_type=DatasourceType(payload["datasource"]["type"]), + datasource_id=payload["datasource"]["id"], ) datasource.metrics.append(SqlMetric(metric_name="foo", expression="select 1;")) diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index a70146db68321..045e368296e85 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -28,10 +28,11 @@ import pytest from flask import current_app +from superset.datasource.dao import DatasourceDAO from superset.models.dashboard import Dashboard -from superset import app, appbuilder, db, security_manager, viz, ConnectorRegistry +from superset import app, appbuilder, db, security_manager, viz from superset.connectors.sqla.models import SqlaTable from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException @@ -990,7 +991,7 @@ def test_get_user_datasources_admin( mock_get_session.query.return_value.filter.return_value.all.return_value = [] with mock.patch.object( - ConnectorRegistry, "get_all_datasources" + SqlaTable, "get_all_datasources" ) as mock_get_all_datasources: mock_get_all_datasources.return_value = [ Datasource("database1", "schema1", "table1"), @@ -1018,7 +1019,7 @@ def test_get_user_datasources_gamma( mock_get_session.query.return_value.filter.return_value.all.return_value = [] with mock.patch.object( - ConnectorRegistry, "get_all_datasources" + SqlaTable, "get_all_datasources" ) as mock_get_all_datasources: mock_get_all_datasources.return_value = [ Datasource("database1", "schema1", "table1"), @@ -1046,7 +1047,7 @@ def test_get_user_datasources_gamma_with_schema( ] with mock.patch.object( - ConnectorRegistry, "get_all_datasources" + SqlaTable, "get_all_datasources" ) as mock_get_all_datasources: mock_get_all_datasources.return_value = [ Datasource("database1", "schema1", "table1"), diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/datasource/dao_tests.py similarity index 81% rename from tests/unit_tests/dao/datasource_test.py rename to tests/unit_tests/datasource/dao_tests.py index a15684d71e699..0682c19c28756 100644 --- a/tests/unit_tests/dao/datasource_test.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -103,7 +103,7 @@ def test_get_datasource_sqlatable( app_context: None, session_with_data: Session ) -> None: from superset.connectors.sqla.models import SqlaTable - from superset.dao.datasource.dao import DatasourceDAO + from superset.datasource.dao import DatasourceDAO result = DatasourceDAO.get_datasource( datasource_type=DatasourceType.TABLE, @@ -117,7 +117,7 @@ def test_get_datasource_sqlatable( def test_get_datasource_query(app_context: None, session_with_data: Session) -> None: - from superset.dao.datasource.dao import DatasourceDAO + from superset.datasource.dao import DatasourceDAO from superset.models.sql_lab import Query result = DatasourceDAO.get_datasource( @@ -131,7 +131,7 @@ def test_get_datasource_query(app_context: None, session_with_data: Session) -> def test_get_datasource_saved_query( app_context: None, session_with_data: Session ) -> None: - from superset.dao.datasource.dao import DatasourceDAO + from superset.datasource.dao import DatasourceDAO from superset.models.sql_lab import SavedQuery result = DatasourceDAO.get_datasource( @@ -145,7 +145,7 @@ def test_get_datasource_saved_query( def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None: - from superset.dao.datasource.dao import DatasourceDAO + from superset.datasource.dao import DatasourceDAO from superset.tables.models import Table # todo(hugh): This will break once we remove the dual write @@ -163,8 +163,8 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session) def test_get_datasource_sl_dataset( app_context: None, session_with_data: Session ) -> None: - from superset.dao.datasource.dao import DatasourceDAO from superset.datasets.models import Dataset + from superset.datasource.dao import DatasourceDAO # todo(hugh): This will break once we remove the dual write # update the datsource_id=1 and this will pass again @@ -178,10 +178,35 @@ def test_get_datasource_sl_dataset( assert isinstance(result, Dataset) -def test_get_all_sqlatables_datasources( +def test_get_datasource_w_str_param( app_context: None, session_with_data: Session ) -> None: - from superset.dao.datasource.dao import DatasourceDAO + from superset.connectors.sqla.models import SqlaTable + from superset.datasets.models import Dataset + from superset.datasource.dao import DatasourceDAO + from superset.tables.models import Table + + assert isinstance( + DatasourceDAO.get_datasource( + datasource_type="table", + datasource_id=1, + session=session_with_data, + ), + SqlaTable, + ) + + assert isinstance( + DatasourceDAO.get_datasource( + datasource_type="sl_table", + datasource_id=1, + session=session_with_data, + ), + Table, + ) + + +def test_get_all_datasources(app_context: None, session_with_data: Session) -> None: + from superset.connectors.sqla.models import SqlaTable - result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data) + result = SqlaTable.get_all_datasources(session=session_with_data) assert len(result) == 1