Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqllab): TRINO_EXPAND_ROWS: expand columns from ROWs #25809

Merged
merged 1 commit into from
Nov 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat(sqllab): TRINO_EXPAND_ROWS: expand trino ROWs
Reimplement, in a simpler way, part of the PRESTO_EXPAND_DATA feature
for trino.

Make use of the trino library's type expansion logic added since the
original feature was introduced, as we now have sqla ROW types parsed
out recursively in a way we can analyse.

Analyse those ROWs and expand the get_columns result to include
dotted-path references to every field that can be queried by dotted path
(i.e. all ROW fields, nested arbitrarily deep).

Add an extra optional query_as field to the column definition so that we
can override the way it's queried: otherwise sqlalchemy is going to
quote the whole thing as a single column name, which isn't correct, it
should actually be quoted per dotted segment, and we'll want to alias it
to the full dotted string.

Add a setting in the database modal to enable this feature.
giftig committed Nov 2, 2023
commit c9c5eba155badae28242696e051fe1552b9bb1bc
4 changes: 2 additions & 2 deletions superset-frontend/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -202,7 +202,7 @@ const ExtraOptions = ({
/>
</div>
</StyledInputContainer>
<StyledInputContainer>
<StyledInputContainer css={no_margin_bottom}>
<div className="input-container">
<IndeterminateCheckbox
id="disable_data_preview"
@@ -220,6 +220,22 @@ const ExtraOptions = ({
/>
</div>
</StyledInputContainer>
<StyledInputContainer>
giftig marked this conversation as resolved.
Show resolved Hide resolved
<div className="input-container">
<IndeterminateCheckbox
giftig marked this conversation as resolved.
Show resolved Hide resolved
id="expand_rows"
indeterminate={false}
checked={!!extraJson?.schema_options?.expand_rows}
onChange={onExtraInputChange}
labelText={t('Enable row expansion in schemas')}
/>
<InfoTooltip
tooltip={t(
'For Trino, describe full schemas of nested ROW types, expanding them with dotted paths',
)}
/>
</div>
</StyledInputContainer>
</StyledExpandableForm>
</StyledInputContainer>
</Collapse.Panel>
Original file line number Diff line number Diff line change
@@ -674,7 +674,7 @@ describe('DatabaseModal', () => {
const exposeInSQLLabCheckbox = screen.getByRole('checkbox', {
name: /expose database in sql lab/i,
});
// This is both the checkbox and it's respective SVG
// This is both the checkbox and its respective SVG
// const exposeInSQLLabCheckboxSVG = checkboxOffSVGs[0].parentElement;
const exposeInSQLLabText = screen.getByText(
/expose database in sql lab/i,
@@ -721,6 +721,13 @@ describe('DatabaseModal', () => {
/Disable SQL Lab data preview queries/i,
);

const enableRowExpansionCheckbox = screen.getByRole('checkbox', {
name: /enable row expansion in schemas/i,
});
const enableRowExpansionText = screen.getByText(
/enable row expansion in schemas/i,
);

// ---------- Assertions ----------
const visibleComponents = [
closeButton,
@@ -737,13 +744,15 @@ describe('DatabaseModal', () => {
checkboxOffSVGs[2],
checkboxOffSVGs[3],
checkboxOffSVGs[4],
checkboxOffSVGs[5],
tooltipIcons[0],
tooltipIcons[1],
tooltipIcons[2],
tooltipIcons[3],
tooltipIcons[4],
tooltipIcons[5],
tooltipIcons[6],
tooltipIcons[7],
exposeInSQLLabText,
allowCTASText,
allowCVASText,
@@ -754,6 +763,7 @@ describe('DatabaseModal', () => {
enableQueryCostEstimationText,
allowDbExplorationText,
disableSQLLabDataPreviewQueriesText,
enableRowExpansionText,
];
// These components exist in the DOM but are not visible
const invisibleComponents = [
@@ -764,15 +774,16 @@ describe('DatabaseModal', () => {
enableQueryCostEstimationCheckbox,
allowDbExplorationCheckbox,
disableSQLLabDataPreviewQueriesCheckbox,
enableRowExpansionCheckbox,
];
visibleComponents.forEach(component => {
expect(component).toBeVisible();
});
invisibleComponents.forEach(component => {
expect(component).not.toBeVisible();
});
expect(checkboxOffSVGs).toHaveLength(5);
expect(tooltipIcons).toHaveLength(7);
expect(checkboxOffSVGs).toHaveLength(6);
expect(tooltipIcons).toHaveLength(8);
});

test('renders the "Advanced" - PERFORMANCE tab correctly', async () => {
12 changes: 12 additions & 0 deletions superset-frontend/src/features/databases/DatabaseModal/index.tsx
Original file line number Diff line number Diff line change
@@ -307,6 +307,18 @@ export function dbReducer(
}),
};
}
if (action.payload.name === 'expand_rows') {
return {
...trimmedState,
extra: JSON.stringify({
...extraJson,
schema_options: {
...extraJson?.schema_options,
[action.payload.name]: !!action.payload.value,
},
}),
};
}
return {
...trimmedState,
extra: JSON.stringify({
3 changes: 3 additions & 0 deletions superset-frontend/src/features/databases/types.ts
Original file line number Diff line number Diff line change
@@ -226,5 +226,8 @@ export interface ExtraJson {
table_cache_timeout?: number; // in Performance
}; // No field, holds schema and table timeout
schemas_allowed_for_file_upload?: string[]; // in Security
schema_options?: {
expand_rows?: boolean;
};
version?: string;
}
19 changes: 15 additions & 4 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql import literal_column, quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE
@@ -1309,15 +1309,21 @@ def get_table_comment(
return comment

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
def get_columns( # pylint: disable=unused-argument
cls,
inspector: Inspector,
table_name: str,
schema: str | None,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
Get all columns from a given schema and table

:param inspector: SqlAlchemy Inspector instance
:param table_name: Table name
:param schema: Schema name. If omitted, uses default schema for database
:param options: Extra options to customise the display of columns in
some databases
:return: All columns in table
"""
return convert_inspector_columns(
@@ -1369,7 +1375,12 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen

@classmethod
def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]:
return [column(c["column_name"]) for c in cols]
return [
literal_column(query_as)
if (query_as := c.get("query_as"))
else column(c["column_name"])
for c in cols
]

@classmethod
def select_star( # pylint: disable=too-many-arguments,too-many-locals
11 changes: 0 additions & 11 deletions superset/db_engine_specs/druid.py
Original file line number Diff line number Diff line change
@@ -23,14 +23,12 @@
from typing import Any, TYPE_CHECKING

from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector

from superset import is_feature_enabled
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.exceptions import SupersetException
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils

if TYPE_CHECKING:
@@ -130,15 +128,6 @@ def epoch_ms_to_dttm(cls) -> str:
"""
return "MILLIS_TO_TIMESTAMP({col})"

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
) -> list[ResultSetColumnType]:
"""
Update the Druid type map.
"""
return super().get_columns(inspector, table_name, schema)

@classmethod
def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel
8 changes: 6 additions & 2 deletions superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
@@ -410,9 +410,13 @@ def handle_cursor( # pylint: disable=too-many-locals

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
cls,
inspector: Inspector,
table_name: str,
schema: str | None,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
return BaseEngineSpec.get_columns(inspector, table_name, schema)
return BaseEngineSpec.get_columns(inspector, table_name, schema, options)

@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
7 changes: 6 additions & 1 deletion superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
@@ -981,14 +981,19 @@ def _show_columns(

@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: str | None
cls,
inspector: Inspector,
table_name: str,
schema: str | None,
options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
:param inspector: object that performs database schema inspection
:param table_name: table name
:param schema: schema name
:param options: Extra configuration options, not used by this backend
:return: a list of results that contain column info
(i.e. column name and data type)
"""
62 changes: 62 additions & 0 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
@@ -24,15 +24,18 @@

import simplejson as json
from flask import current_app
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from trino.sqlalchemy import datatype

from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError
from superset.db_engine_specs.presto import PrestoBaseEngineSpec
from superset.models.sql_lab import Query
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils

if TYPE_CHECKING:
@@ -325,3 +328,62 @@ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
return {
requests_exceptions.ConnectionError: SupersetDBAPIConnectionError,
}

@classmethod
def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]:
    """
    Recursively expand a column into itself plus any nested ROW fields.

    Only named fields of ROW types can be addressed by a dotted path, so
    MAP and ARRAY types (and any ROWs nested beneath them) are not
    expanded; the trino library also fails to parse unnamed ROW fields,
    which therefore cannot be expanded either.

    Each expanded column is named ``foo.bar.baz`` and carries a
    ``query_as`` expression that quotes every path segment individually
    ("foo"."bar"."baz") and aliases the result back to the full dotted
    name — quoting the whole dotted string as one identifier would be
    incorrect.
    """
    expanded: list[ResultSetColumnType] = [col]

    parent_type = col.get("type")
    if not isinstance(parent_type, datatype.ROW):
        # Leaf column (or an unexpandable container type): nothing to add.
        return expanded

    parent_name = col["name"]
    for field_name, field_type in parent_type.attr_types:
        dotted = f"{parent_name}.{field_name}"
        # Quote each dotted segment separately, then alias to the full path.
        quoted_path = ".".join(f'"{segment}"' for segment in dotted.split("."))
        spec = cls.get_column_spec(str(field_type))

        child = ResultSetColumnType(
            name=dotted,
            column_name=dotted,
            type=field_type,
            is_dttm=spec.is_dttm if spec else False,
            query_as=f'{quoted_path} AS "{dotted}"',
        )
        expanded.extend(cls._expand_columns(child))

    return expanded

@classmethod
def get_columns(
    cls,
    inspector: Inspector,
    table_name: str,
    schema: str | None,
    options: dict[str, Any] | None = None,
) -> list[ResultSetColumnType]:
    """
    Fetch the table's columns, optionally expanding nested ROW types.

    When the database's "schema_options" carry a truthy "expand_rows",
    every ROW column is followed by extra pseudo-columns for each nested
    field, named by dotted path; otherwise the base columns are returned
    unchanged.
    """
    columns = super().get_columns(inspector, table_name, schema, options)

    if not (options or {}).get("expand_rows"):
        return columns

    expanded: list[ResultSetColumnType] = []
    for col in columns:
        expanded.extend(cls._expand_columns(col))
    return expanded
10 changes: 9 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
@@ -236,6 +236,11 @@ def disable_data_preview(self) -> bool:
# this will prevent any 'trash value' strings from going through
return self.get_extra().get("disable_data_preview", False) is True

@property
def schema_options(self) -> dict[str, Any]:
"""Additional schema display config for engines with complex schemas"""
return self.get_extra().get("schema_options", {})

@property
def data(self) -> dict[str, Any]:
return {
@@ -247,6 +252,7 @@ def data(self) -> dict[str, Any]:
"allows_cost_estimate": self.allows_cost_estimate,
"allows_virtual_table_explore": self.allows_virtual_table_explore,
"explore_database_id": self.explore_database_id,
"schema_options": self.schema_options,
"parameters": self.parameters,
"disable_data_preview": self.disable_data_preview,
"parameters_schema": self.parameters_schema,
@@ -837,7 +843,9 @@ def get_columns(
self, table_name: str, schema: str | None = None
) -> list[ResultSetColumnType]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_columns(inspector, table_name, schema)
return self.db_engine_spec.get_columns(
inspector, table_name, schema, self.schema_options
)

def get_metrics(
self,
2 changes: 2 additions & 0 deletions superset/superset_typing.py
Original file line number Diff line number Diff line change
@@ -84,6 +84,8 @@ class ResultSetColumnType(TypedDict):
scale: NotRequired[Any]
max_length: NotRequired[Any]

query_as: NotRequired[Any]


CacheConfig = dict[str, Any]
DbapiDescriptionRow = tuple[
122 changes: 122 additions & 0 deletions tests/unit_tests/db_engine_specs/test_trino.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import copy
import json
from datetime import datetime
from typing import Any, Optional
@@ -24,9 +25,11 @@
import pytest
from pytest_mock import MockerFixture
from sqlalchemy import types
from trino.sqlalchemy import datatype

import superset.config
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
assert_column_spec,
@@ -35,6 +38,24 @@
from tests.unit_tests.fixtures.common import dttm


def _assert_columns_equal(actual_cols, expected_cols) -> None:
"""
Assert equality of the given cols, bearing in mind sqlalchemy type
instances can't be compared for equality, so will have to be converted to
strings first.
"""
actual = copy.deepcopy(actual_cols)
expected = copy.deepcopy(expected_cols)

for col in actual:
col["type"] = str(col["type"])

for col in expected:
col["type"] = str(col["type"])

assert actual == expected


@pytest.mark.parametrize(
"extra,expected",
[
@@ -395,3 +416,104 @@ def _mock_execute(*args, **kwargs):
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)


def test_get_columns(mocker: MockerFixture):
    """ROW columns stay unexpanded when expand_rows is not enabled"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    col_types = {
        "field1": datatype.parse_sqltype("row(a varchar, b date)"),
        "field2": datatype.parse_sqltype("row(r1 row(a varchar, b varchar))"),
        "field3": datatype.parse_sqltype("int"),
    }

    inspector = mocker.MagicMock()
    inspector.get_columns.return_value = [
        SQLAColumnType(name=name, type=type_, is_dttm=False)
        for name, type_ in col_types.items()
    ]

    actual = TrinoEngineSpec.get_columns(inspector, "table", "schema")
    expected = [
        ResultSetColumnType(name=name, column_name=name, type=type_, is_dttm=False)
        for name, type_ in col_types.items()
    ]

    _assert_columns_equal(actual, expected)


def test_get_columns_expand_rows(mocker: MockerFixture):
    """ROW columns are expanded into dotted-path fields with expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    inspector = mocker.MagicMock()
    inspector.get_columns.return_value = [
        SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
        SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
        SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
    ]

    actual = TrinoEngineSpec.get_columns(
        inspector, "table", "schema", {"expand_rows": True}
    )
    # ResultSetColumnType is a TypedDict, so plain dict literals are
    # runtime-equivalent; each expanded field carries a query_as that
    # quotes per-segment and aliases back to the dotted name.
    expected = [
        {"name": "field1", "column_name": "field1", "type": field1_type, "is_dttm": False},
        {
            "name": "field1.a",
            "column_name": "field1.a",
            "type": types.VARCHAR(),
            "is_dttm": False,
            "query_as": '"field1"."a" AS "field1.a"',
        },
        {
            "name": "field1.b",
            "column_name": "field1.b",
            "type": types.DATE(),
            "is_dttm": True,
            "query_as": '"field1"."b" AS "field1.b"',
        },
        {"name": "field2", "column_name": "field2", "type": field2_type, "is_dttm": False},
        {
            "name": "field2.r1",
            "column_name": "field2.r1",
            "type": datatype.parse_sqltype("row(a varchar, b varchar)"),
            "is_dttm": False,
            "query_as": '"field2"."r1" AS "field2.r1"',
        },
        {
            "name": "field2.r1.a",
            "column_name": "field2.r1.a",
            "type": types.VARCHAR(),
            "is_dttm": False,
            "query_as": '"field2"."r1"."a" AS "field2.r1.a"',
        },
        {
            "name": "field2.r1.b",
            "column_name": "field2.r1.b",
            "type": types.VARCHAR(),
            "is_dttm": False,
            "query_as": '"field2"."r1"."b" AS "field2.r1.b"',
        },
        {"name": "field3", "column_name": "field3", "type": field3_type, "is_dttm": False},
    ]

    _assert_columns_equal(actual, expected)