Skip to content

Commit

Permalink
address concerns
Browse files Browse the repository at this point in the history
  • Loading branch information
hughhhh authored May 12, 2022
1 parent bcfc683 commit 9983ce9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 67 deletions.
78 changes: 39 additions & 39 deletions superset/dao/datasource/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from superset.connectors.sqla.models import SqlaTable
from superset.dao.base import BaseDAO
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.models import Dataset
from superset.models.core import Database
Expand All @@ -46,10 +47,10 @@ class DatasourceDAO(BaseDAO):

@classmethod
def get_datasource(
cls, datasource_type: DatasourceType, datasource_id: int, session: Session
cls, session: Session, datasource_type: DatasourceType, datasource_id: int
) -> Datasource:
if datasource_type not in cls.sources:
raise DatasetNotFoundError()
raise DatasourceTypeNotSupportedError()

datasource = (
session.query(cls.sources[datasource_type])
Expand All @@ -58,28 +59,25 @@ def get_datasource(
)

if not datasource:
raise DatasetNotFoundError()
raise DatasourceNotFound()

return datasource

@classmethod
def get_all_datasources(cls, session: Session) -> List[Datasource]:
datasources: List[Datasource] = []
for source_class in DatasourceDAO.sources.values():
qry = session.query(source_class)
if isinstance(source_class, SqlaTable):
qry = source_class.default_query(qry)
datasources.extend(qry.all())
return datasources
def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]:
source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE]
qry = session.query(source_class)
qry = source_class.default_query(qry)
return qry.all()

@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
session: Session,
datasource_type: DatasourceType,
datasource_name: str,
schema: str,
database_name: str,
schema: str,
) -> Optional[Datasource]:
datasource_class = DatasourceDAO.sources[datasource_type]
if isinstance(datasource_class, SqlaTable):
Expand All @@ -96,39 +94,40 @@ def query_datasources_by_permissions( # pylint: disable=invalid-name
permissions: Set[str],
schema_perms: Set[str],
) -> List[Datasource]:
# TODO(bogdan): add unit test
# TODO(hughhhh): add unit test
datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
if isinstance(datasource_class, SqlaTable):
return (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
datasource_class.perm.in_(permissions),
datasource_class.schema_perm.in_(schema_perms),
)
if not isinstance(datasource_class, SqlaTable):
return []

return (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
datasource_class.perm.in_(permissions),
datasource_class.schema_perm.in_(schema_perms),
)
.all()
)
return []
.all()
)

@classmethod
def get_eager_datasource(
    cls, session: Session, datasource_type: str, datasource_id: int
) -> Optional[Datasource]:
    """Return the datasource with its columns and metrics eagerly loaded.

    :param session: SQLAlchemy session used to run the query
    :param datasource_type: key looked up via ``DatasourceType[...]`` to
        select the model class from ``DatasourceDAO.sources``
    :param datasource_id: value matched against the model's ``id`` column
    :returns: the matching row with ``columns`` and ``metrics`` loaded via
        ``subqueryload``, or ``None`` when the resolved model class is not
        ``SqlaTable`` (or a subclass of it)
    :raises KeyError: if ``datasource_type`` cannot be resolved to an entry
        of ``DatasourceDAO.sources``
    :raises sqlalchemy.orm.exc.NoResultFound: propagated from ``.one()``
        when no row has the given id
    :raises sqlalchemy.orm.exc.MultipleResultsFound: propagated from
        ``.one()`` when more than one row matches
    """
    datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]]

    # The values of ``sources`` are handed straight to ``session.query(...)``
    # elsewhere in this DAO, i.e. they are model classes rather than
    # instances. ``isinstance(<class>, SqlaTable)`` is always False for a
    # class object, which made this method return None for every input;
    # ``issubclass`` is the check that was intended. The ``type`` guard keeps
    # a non-class entry from raising TypeError inside ``issubclass``.
    is_sqla_table = isinstance(datasource_class, type) and issubclass(
        datasource_class, SqlaTable
    )
    if not is_sqla_table:
        return None

    return (
        session.query(datasource_class)
        .options(
            subqueryload(datasource_class.columns),
            subqueryload(datasource_class.metrics),
        )
        .filter_by(id=datasource_id)
        .one()
    )

@classmethod
def query_datasources_by_name(
Expand All @@ -139,8 +138,9 @@ def query_datasources_by_name(
schema: Optional[str] = None,
) -> List[Datasource]:
datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
if isinstance(datasource_class, SqlaTable):
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
)
return []
if not isinstance(datasource_class, SqlaTable):
return []

return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
)
12 changes: 12 additions & 0 deletions superset/dao/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,15 @@ class DAOConfigError(DAOException):
"""

message = "DAO is not configured correctly missing model definition"


class DatasourceTypeNotSupportedError(DAOException):
    """
    DAO datasource query source type is not supported.

    Raised by ``DatasourceDAO.get_datasource`` when the requested
    datasource type has no entry in ``DatasourceDAO.sources``.
    """

    # NOTE(review): presumably consumed by DAOException as the error text;
    # confirm against the base class, which is not shown here.
    message = "DAO datasource query source type is not supported"


class DatasourceNotFound(DAOException):
    """
    Datasource lookup matched no record.

    Raised by ``DatasourceDAO.get_datasource`` when the query for the given
    type and id yields a falsy result.
    """

    # NOTE(review): presumably consumed by DAOException as the error text;
    # confirm against the base class, which is not shown here.
    message = "Datasource does not exist"
60 changes: 32 additions & 28 deletions tests/unit_tests/dao/datasource_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
# specific language governing permissions and limitations
# under the License.

from typing import Iterator

import pytest
from sqlalchemy.orm.session import Session

from superset.utils.core import DatasourceType


def create_test_data(session: Session) -> None:
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.columns.models import Column
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.datasets.models import Dataset
Expand Down Expand Up @@ -93,89 +96,90 @@ def create_test_data(session: Session) -> None:
session.add(db)
session.add(sqla_table)
session.flush()
yield session


def test_get_datasource_sqlatable(app_context: None, session: Session) -> None:
def test_get_datasource_sqlatable(
app_context: None, session_with_data: Session
) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.dao.datasource.dao import DatasourceDAO

create_test_data(session)

result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.SQLATABLE, datasource_id=1, session=session
datasource_type=DatasourceType.SQLATABLE,
datasource_id=1,
session=session_with_data,
)

assert 1 == result.id
assert "my_sqla_table" == result.table_name
assert isinstance(result, SqlaTable)


def test_get_datasource_query(app_context: None, session: Session) -> None:
def test_get_datasource_query(app_context: None, session_with_data: Session) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.models.sql_lab import Query

create_test_data(session)

result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.QUERY, datasource_id=1, session=session
datasource_type=DatasourceType.QUERY, datasource_id=1, session=session_with_data
)

assert result.id == 1
assert isinstance(result, Query)


def test_get_datasource_saved_query(app_context: None, session: Session) -> None:
def test_get_datasource_saved_query(
app_context: None, session_with_data: Session
) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.models.sql_lab import SavedQuery

create_test_data(session)

result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.SAVEDQUERY, datasource_id=1, session=session
datasource_type=DatasourceType.SAVEDQUERY,
datasource_id=1,
session=session_with_data,
)

assert result.id == 1
assert isinstance(result, SavedQuery)


def test_get_datasource_sl_table(app_context: None, session: Session) -> None:
def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.tables.models import Table

create_test_data(session)

# todo(hugh): This will break once we remove the dual write
# update the datasource_id=1 and this will pass again
result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.TABLE, datasource_id=2, session=session
datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data
)

assert result.id == 2
assert isinstance(result, Table)


def test_get_datasource_sl_dataset(app_context: None, session: Session) -> None:
def test_get_datasource_sl_dataset(
app_context: None, session_with_data: Session
) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.datasets.models import Dataset

create_test_data(session)

# todo(hugh): This will break once we remove the dual write
# update the datasource_id=1 and this will pass again
result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.DATASET, datasource_id=2, session=session
datasource_type=DatasourceType.DATASET,
datasource_id=2,
session=session_with_data,
)

assert result.id == 2
assert isinstance(result, Dataset)


def test_get_all_datasources(app_context: None, session: Session) -> None:
def test_get_all_sqlatables_datasources(
app_context: None, session_with_data: Session
) -> None:
from superset.dao.datasource.dao import DatasourceDAO

create_test_data(session)

# todo(hugh): This will break once we remove the dual write
# update the assert len(result) == 5 and this will pass again
result = DatasourceDAO.get_all_datasources(session=session)
assert len(result) == 7
result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data)
assert len(result) == 1

0 comments on commit 9983ce9

Please sign in to comment.