From c9c5eba155badae28242696e051fe1552b9bb1bc Mon Sep 17 00:00:00 2001 From: Rob Moore Date: Fri, 27 Oct 2023 15:51:47 +0100 Subject: [PATCH] feat(sqllab): TRINO_EXPAND_ROWS: expand trino ROWs Reimplement, in a simpler way, part of the PRESTO_EXPAND_DATA feature for trino. Make use of the trino library's type expansion logic added since the original feature was introduced, as we now have sqla ROW types parsed out recursively in a way we can analyse. Analyse those ROWs and expand our definition of get_columns out to include dotted.path references to all fields it's possible to query by dotted path (i.e. the named fields of all ROWs, nested arbitrarily deep). Add an extra optional query_as field to the column definition so that we can override the way it's queried: otherwise sqlalchemy is going to quote the whole thing as a single column name, which isn't correct; it should actually be quoted per dotted segment, and we'll want to alias it to the full dotted string. Add a setting in the database modal to enable this feature. 
--- superset-frontend/package-lock.json | 4 +- .../databases/DatabaseModal/ExtraOptions.tsx | 18 ++- .../databases/DatabaseModal/index.test.tsx | 17 ++- .../databases/DatabaseModal/index.tsx | 12 ++ .../src/features/databases/types.ts | 3 + superset/db_engine_specs/base.py | 19 ++- superset/db_engine_specs/druid.py | 11 -- superset/db_engine_specs/hive.py | 8 +- superset/db_engine_specs/presto.py | 7 +- superset/db_engine_specs/trino.py | 62 +++++++++ superset/models/core.py | 10 +- superset/superset_typing.py | 2 + .../unit_tests/db_engine_specs/test_trino.py | 122 ++++++++++++++++++ 13 files changed, 270 insertions(+), 25 deletions(-) diff --git a/superset-frontend/package-lock.json b/superset-frontend/package-lock.json index 370cd3f1f246e..8174442a4567d 100644 --- a/superset-frontend/package-lock.json +++ b/superset-frontend/package-lock.json @@ -284,8 +284,8 @@ "webpack-sources": "^3.2.3" }, "engines": { - "node": "^16.9.1", - "npm": "^7.5.4 || ^8.1.2" + "node": "^16.20.2", + "npm": "^8.19.4" } }, "buildtools/eslint-plugin-theme-colors": { diff --git a/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx b/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx index 207e197cd40d9..51ab8524eeb8d 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/ExtraOptions.tsx @@ -202,7 +202,7 @@ const ExtraOptions = ({ /> - +
+ +
+ + +
+
diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx index bcd9fbe694706..ba443e0099457 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.test.tsx @@ -674,7 +674,7 @@ describe('DatabaseModal', () => { const exposeInSQLLabCheckbox = screen.getByRole('checkbox', { name: /expose database in sql lab/i, }); - // This is both the checkbox and it's respective SVG + // This is both the checkbox and its respective SVG // const exposeInSQLLabCheckboxSVG = checkboxOffSVGs[0].parentElement; const exposeInSQLLabText = screen.getByText( /expose database in sql lab/i, @@ -721,6 +721,13 @@ describe('DatabaseModal', () => { /Disable SQL Lab data preview queries/i, ); + const enableRowExpansionCheckbox = screen.getByRole('checkbox', { + name: /enable row expansion in schemas/i, + }); + const enableRowExpansionText = screen.getByText( + /enable row expansion in schemas/i, + ); + // ---------- Assertions ---------- const visibleComponents = [ closeButton, @@ -737,6 +744,7 @@ describe('DatabaseModal', () => { checkboxOffSVGs[2], checkboxOffSVGs[3], checkboxOffSVGs[4], + checkboxOffSVGs[5], tooltipIcons[0], tooltipIcons[1], tooltipIcons[2], @@ -744,6 +752,7 @@ describe('DatabaseModal', () => { tooltipIcons[4], tooltipIcons[5], tooltipIcons[6], + tooltipIcons[7], exposeInSQLLabText, allowCTASText, allowCVASText, @@ -754,6 +763,7 @@ describe('DatabaseModal', () => { enableQueryCostEstimationText, allowDbExplorationText, disableSQLLabDataPreviewQueriesText, + enableRowExpansionText, ]; // These components exist in the DOM but are not visible const invisibleComponents = [ @@ -764,6 +774,7 @@ describe('DatabaseModal', () => { enableQueryCostEstimationCheckbox, allowDbExplorationCheckbox, disableSQLLabDataPreviewQueriesCheckbox, + enableRowExpansionCheckbox, ]; 
visibleComponents.forEach(component => { expect(component).toBeVisible(); @@ -771,8 +782,8 @@ describe('DatabaseModal', () => { invisibleComponents.forEach(component => { expect(component).not.toBeVisible(); }); - expect(checkboxOffSVGs).toHaveLength(5); - expect(tooltipIcons).toHaveLength(7); + expect(checkboxOffSVGs).toHaveLength(6); + expect(tooltipIcons).toHaveLength(8); }); test('renders the "Advanced" - PERFORMANCE tab correctly', async () => { diff --git a/superset-frontend/src/features/databases/DatabaseModal/index.tsx b/superset-frontend/src/features/databases/DatabaseModal/index.tsx index 0c1ac56369692..18c93f2bf462f 100644 --- a/superset-frontend/src/features/databases/DatabaseModal/index.tsx +++ b/superset-frontend/src/features/databases/DatabaseModal/index.tsx @@ -307,6 +307,18 @@ export function dbReducer( }), }; } + if (action.payload.name === 'expand_rows') { + return { + ...trimmedState, + extra: JSON.stringify({ + ...extraJson, + schema_options: { + ...extraJson?.schema_options, + [action.payload.name]: !!action.payload.value, + }, + }), + }; + } return { ...trimmedState, extra: JSON.stringify({ diff --git a/superset-frontend/src/features/databases/types.ts b/superset-frontend/src/features/databases/types.ts index e138a9143669e..1d616fa13c053 100644 --- a/superset-frontend/src/features/databases/types.ts +++ b/superset-frontend/src/features/databases/types.ts @@ -226,5 +226,8 @@ export interface ExtraJson { table_cache_timeout?: number; // in Performance }; // No field, holds schema and table timeout schemas_allowed_for_file_upload?: string[]; // in Security + schema_options?: { + expand_rows?: boolean; + }; version?: string; } diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index f355e4ef8cea8..4ba58cb21deef 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -51,7 +51,7 @@ from sqlalchemy.engine.url import URL from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm 
import Session -from sqlalchemy.sql import quoted_name, text +from sqlalchemy.sql import literal_column, quoted_name, text from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause from sqlalchemy.types import TypeEngine from sqlparse.tokens import CTE @@ -1309,8 +1309,12 @@ def get_table_comment( return comment @classmethod - def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + def get_columns( # pylint: disable=unused-argument + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ Get all columns from a given schema and table @@ -1318,6 +1322,8 @@ def get_columns( :param inspector: SqlAlchemy Inspector instance :param table_name: Table name :param schema: Schema name. If omitted, uses default schema for database + :param options: Extra options to customise the display of columns in + some databases :return: All columns in table """ return convert_inspector_columns( @@ -1369,7 +1375,12 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen @classmethod def _get_fields(cls, cols: list[ResultSetColumnType]) -> list[Any]: - return [column(c["column_name"]) for c in cols] + return [ + literal_column(query_as) + if (query_as := c.get("query_as")) + else column(c["column_name"]) + for c in cols + ] @classmethod def select_star( # pylint: disable=too-many-arguments,too-many-locals diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 9bba3a727438b..7cd85ec924cf9 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -23,14 +23,12 @@ from typing import Any, TYPE_CHECKING from sqlalchemy import types -from sqlalchemy.engine.reflection import Inspector from superset import is_feature_enabled from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.exceptions 
import SupersetDBAPIConnectionError from superset.exceptions import SupersetException -from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: @@ -130,15 +128,6 @@ def epoch_ms_to_dttm(cls) -> str: """ return "MILLIS_TO_TIMESTAMP({col})" - @classmethod - def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None - ) -> list[ResultSetColumnType]: - """ - Update the Druid type map. - """ - return super().get_columns(inspector, table_name, schema) - @classmethod def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 4a881e15b276b..bd303f928d625 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -410,9 +410,13 @@ def handle_cursor( # pylint: disable=too-many-locals @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: - return BaseEngineSpec.get_columns(inspector, table_name, schema) + return BaseEngineSpec.get_columns(inspector, table_name, schema, options) @classmethod def where_latest_partition( # pylint: disable=too-many-arguments diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 8afa82d9b55d9..27e86a7980875 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -981,7 +981,11 @@ def _show_columns( @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: str | None + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, ) -> list[ResultSetColumnType]: """ Get columns from a Presto data source. 
This includes handling row and @@ -989,6 +993,7 @@ def get_columns( :param inspector: object that performs database schema inspection :param table_name: table name :param schema: schema name + :param options: Extra configuration options, not used by this backend :return: a list of results that contain column info (i.e. column name and data type) """ diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 425137e302e6b..76b8dbea236b2 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -24,8 +24,10 @@ import simplejson as json from flask import current_app +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session +from trino.sqlalchemy import datatype from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT from superset.databases.utils import make_url_safe @@ -33,6 +35,7 @@ from superset.db_engine_specs.exceptions import SupersetDBAPIConnectionError from superset.db_engine_specs.presto import PrestoBaseEngineSpec from superset.models.sql_lab import Query +from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils if TYPE_CHECKING: @@ -325,3 +328,62 @@ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return { requests_exceptions.ConnectionError: SupersetDBAPIConnectionError, } + + @classmethod + def _expand_columns(cls, col: ResultSetColumnType) -> list[ResultSetColumnType]: + """ + Expand the given column out to one or more columns by analysing their types, + descending into ROWS and expanding out their inner fields recursively. + + We can only navigate named fields in ROWs in this way, so we can't expand out + MAP or ARRAY types, nor fields in ROWs which have no name (in fact the trino + library doesn't correctly parse unnamed fields in ROWs). We won't be able to + expand ROWs which are nested underneath any of those types, either. 
+ + Expanded columns are named foo.bar.baz and we provide a query_as property to + instruct the base engine spec how to correctly query them: instead of quoting + the whole string they have to be quoted like "foo"."bar"."baz" and we then + alias them to the full dotted string for ease of reference. + """ + cols = [col] + col_type = col.get("type") + + if not isinstance(col_type, datatype.ROW): + return cols + + for inner_name, inner_type in col_type.attr_types: + outer_name = col["name"] + name = ".".join([outer_name, inner_name]) + query_name = ".".join([f'"{piece}"' for piece in name.split(".")]) + column_spec = cls.get_column_spec(str(inner_type)) + is_dttm = column_spec.is_dttm if column_spec else False + + inner_col = ResultSetColumnType( + name=name, + column_name=name, + type=inner_type, + is_dttm=is_dttm, + query_as=f'{query_name} AS "{name}"', + ) + cols.extend(cls._expand_columns(inner_col)) + + return cols + + @classmethod + def get_columns( + cls, + inspector: Inspector, + table_name: str, + schema: str | None, + options: dict[str, Any] | None = None, + ) -> list[ResultSetColumnType]: + """ + If the "expand_rows" feature is enabled on the database via + "schema_options", expand the schema definition out to show all + subfields of nested ROWs as their appropriate dotted paths. 
+ """ + base_cols = super().get_columns(inspector, table_name, schema, options) + if not (options or {}).get("expand_rows"): + return base_cols + + return [col for base_col in base_cols for col in cls._expand_columns(base_col)] diff --git a/superset/models/core.py b/superset/models/core.py index f6e4b972b48db..81e138aa30b2f 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -236,6 +236,11 @@ def disable_data_preview(self) -> bool: # this will prevent any 'trash value' strings from going through return self.get_extra().get("disable_data_preview", False) is True + @property + def schema_options(self) -> dict[str, Any]: + """Additional schema display config for engines with complex schemas""" + return self.get_extra().get("schema_options", {}) + @property def data(self) -> dict[str, Any]: return { @@ -247,6 +252,7 @@ def data(self) -> dict[str, Any]: "allows_cost_estimate": self.allows_cost_estimate, "allows_virtual_table_explore": self.allows_virtual_table_explore, "explore_database_id": self.explore_database_id, + "schema_options": self.schema_options, "parameters": self.parameters, "disable_data_preview": self.disable_data_preview, "parameters_schema": self.parameters_schema, @@ -837,7 +843,9 @@ def get_columns( self, table_name: str, schema: str | None = None ) -> list[ResultSetColumnType]: with self.get_inspector_with_context() as inspector: - return self.db_engine_spec.get_columns(inspector, table_name, schema) + return self.db_engine_spec.get_columns( + inspector, table_name, schema, self.schema_options + ) def get_metrics( self, diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 953683b5dcd01..c71dcea3f1a2d 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -84,6 +84,8 @@ class ResultSetColumnType(TypedDict): scale: NotRequired[Any] max_length: NotRequired[Any] + query_as: NotRequired[Any] + CacheConfig = dict[str, Any] DbapiDescriptionRow = tuple[ diff --git 
a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 1b50a683a0841..15e55fc5af62f 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument, import-outside-toplevel, protected-access +import copy import json from datetime import datetime from typing import Any, Optional @@ -24,9 +25,11 @@ import pytest from pytest_mock import MockerFixture from sqlalchemy import types +from trino.sqlalchemy import datatype import superset.config from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT +from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, @@ -35,6 +38,24 @@ from tests.unit_tests.fixtures.common import dttm +def _assert_columns_equal(actual_cols, expected_cols) -> None: + """ + Assert equality of the given cols, bearing in mind sqlalchemy type + instances can't be compared for equality, so will have to be converted to + strings first. 
+ """ + actual = copy.deepcopy(actual_cols) + expected = copy.deepcopy(expected_cols) + + for col in actual: + col["type"] = str(col["type"]) + + for col in expected: + col["type"] = str(col["type"]) + + assert actual == expected + + @pytest.mark.parametrize( "extra,expected", [ @@ -395,3 +416,104 @@ def _mock_execute(*args, **kwargs): mock_query.set_extra_json_key.assert_called_once_with( key=QUERY_CANCEL_KEY, value=query_id ) + + +def test_get_columns(mocker: MockerFixture): + """Test that ROW columns are not expanded without expand_rows""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + field1_type = datatype.parse_sqltype("row(a varchar, b date)") + field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") + field3_type = datatype.parse_sqltype("int") + + sqla_columns = [ + SQLAColumnType(name="field1", type=field1_type, is_dttm=False), + SQLAColumnType(name="field2", type=field2_type, is_dttm=False), + SQLAColumnType(name="field3", type=field3_type, is_dttm=False), + ] + mock_inspector = mocker.MagicMock() + mock_inspector.get_columns.return_value = sqla_columns + + actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema") + expected = [ + ResultSetColumnType( + name="field1", column_name="field1", type=field1_type, is_dttm=False + ), + ResultSetColumnType( + name="field2", column_name="field2", type=field2_type, is_dttm=False + ), + ResultSetColumnType( + name="field3", column_name="field3", type=field3_type, is_dttm=False + ), + ] + + _assert_columns_equal(actual, expected) + + +def test_get_columns_expand_rows(mocker: MockerFixture): + """Test that ROW columns are correctly expanded with expand_rows""" + from superset.db_engine_specs.trino import TrinoEngineSpec + + field1_type = datatype.parse_sqltype("row(a varchar, b date)") + field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))") + field3_type = datatype.parse_sqltype("int") + + sqla_columns = [ + SQLAColumnType(name="field1", 
type=field1_type, is_dttm=False), + SQLAColumnType(name="field2", type=field2_type, is_dttm=False), + SQLAColumnType(name="field3", type=field3_type, is_dttm=False), + ] + mock_inspector = mocker.MagicMock() + mock_inspector.get_columns.return_value = sqla_columns + + actual = TrinoEngineSpec.get_columns( + mock_inspector, "table", "schema", {"expand_rows": True} + ) + expected = [ + ResultSetColumnType( + name="field1", column_name="field1", type=field1_type, is_dttm=False + ), + ResultSetColumnType( + name="field1.a", + column_name="field1.a", + type=types.VARCHAR(), + is_dttm=False, + query_as='"field1"."a" AS "field1.a"', + ), + ResultSetColumnType( + name="field1.b", + column_name="field1.b", + type=types.DATE(), + is_dttm=True, + query_as='"field1"."b" AS "field1.b"', + ), + ResultSetColumnType( + name="field2", column_name="field2", type=field2_type, is_dttm=False + ), + ResultSetColumnType( + name="field2.r1", + column_name="field2.r1", + type=datatype.parse_sqltype("row(a varchar, b varchar)"), + is_dttm=False, + query_as='"field2"."r1" AS "field2.r1"', + ), + ResultSetColumnType( + name="field2.r1.a", + column_name="field2.r1.a", + type=types.VARCHAR(), + is_dttm=False, + query_as='"field2"."r1"."a" AS "field2.r1.a"', + ), + ResultSetColumnType( + name="field2.r1.b", + column_name="field2.r1.b", + type=types.VARCHAR(), + is_dttm=False, + query_as='"field2"."r1"."b" AS "field2.r1.b"', + ), + ResultSetColumnType( + name="field3", column_name="field3", type=field3_type, is_dttm=False + ), + ] + + _assert_columns_equal(actual, expected)