From 512cfd8d5faa84e5db09c9acf0ce8a79950b269b Mon Sep 17 00:00:00 2001 From: John Bodley Date: Fri, 16 Dec 2022 07:58:11 +1300 Subject: [PATCH] chore: Re-add inheritance of Presto macros for Trino --- superset/db_engine_specs/presto.py | 411 +++++++++--------- superset/db_engine_specs/trino.py | 33 +- .../db_engine_specs/presto_tests.py | 3 +- .../db_engine_specs/trino_tests.py | 16 + 4 files changed, 253 insertions(+), 210 deletions(-) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 675503973485a..2a3acb8bb5f1a 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -311,6 +311,203 @@ def get_function_names(cls, database: Database) -> List[str]: """ return database.get_df("SHOW FUNCTIONS")["Function"].tolist() + @classmethod + def _partition_query( # pylint: disable=too-many-arguments,too-many-locals + cls, + table_name: str, + database: Database, + limit: int = 0, + order_by: Optional[List[Tuple[str, bool]]] = None, + filters: Optional[Dict[Any, Any]] = None, + ) -> str: + """Returns a partition query + + :param table_name: the name of the table to get partitions from + :type table_name: str + :param limit: the number of partitions to be returned + :type limit: int + :param order_by: a list of tuples of field name and a boolean + that determines if that field should be sorted in descending + order + :type order_by: list of (str, bool) tuples + :param filters: dict of field name and filter value combinations + """ + limit_clause = "LIMIT {}".format(limit) if limit else "" + order_by_clause = "" + if order_by: + l = [] + for field, desc in order_by: + l.append(field + " DESC" if desc else "") + order_by_clause = "ORDER BY " + ", ".join(l) + + where_clause = "" + if filters: + l = [] + for field, value in filters.items(): + l.append(f"{field} = '{value}'") + where_clause = "WHERE " + " AND ".join(l) + + presto_version = database.get_extra().get("version") + + # Partition select syntax changed in v0.199, so check here. + # Default to the new syntax if version is unset. + partition_select_clause = ( + f'SELECT * FROM "{table_name}$partitions"' + if not presto_version + or StrictVersion(presto_version) >= StrictVersion("0.199") + else f"SHOW PARTITIONS FROM {table_name}" + ) + + sql = dedent( + f"""\ + {partition_select_clause} + {where_clause} + {order_by_clause} + {limit_clause} + """ + ) + return sql + + @classmethod + def where_latest_partition( # pylint: disable=too-many-arguments + cls, + table_name: str, + schema: Optional[str], + database: Database, + query: Select, + columns: Optional[List[Dict[str, str]]] = None, + ) -> Optional[Select]: + try: + col_names, values = cls.latest_partition( + table_name, schema, database, show_first=True + ) + except Exception: # pylint: disable=broad-except + # table is not partitioned + return None + + if values is None: + return None + + column_type_by_name = { + column.get("name"): column.get("type") for column in columns or [] + } + + for col_name, value in zip(col_names, values): + if col_name in column_type_by_name: + if column_type_by_name.get(col_name) == "TIMESTAMP": + query = query.where(Column(col_name, TimeStamp()) == value) + elif column_type_by_name.get(col_name) == "DATE": + query = query.where(Column(col_name, Date()) == value) + else: + query = query.where(Column(col_name) == value) + return query + + @classmethod + def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: + if not df.empty: + return df.to_records(index=False)[0].item() + return None + + @classmethod + @cache_manager.data_cache.memoize(timeout=60) + def latest_partition( + cls, + table_name: str, + schema: Optional[str], + database: Database, + show_first: bool = False, + ) -> Tuple[List[str], Optional[List[str]]]: + """Returns col name and the latest (max) partition value for a table + + :param table_name: the name of the table + :param schema: schema / database / namespace + :param database: database query will be run against + :type database: models.Database + :param show_first: displays the value for the first partitioning key + if there are many partitioning keys + :type show_first: bool + + >>> latest_partition('foo_table') + (['ds'], ('2018-01-01',)) + """ + indexes = database.get_indexes(table_name, schema) + if not indexes: + raise SupersetTemplateException( + f"Error getting partition for {schema}.{table_name}. " + "Verify that this table has a partition." + ) + + if len(indexes[0]["column_names"]) < 1: + raise SupersetTemplateException( + "The table should have one partitioned field" + ) + + if not show_first and len(indexes[0]["column_names"]) > 1: + raise SupersetTemplateException( + "The table should have a single partitioned field " + "to use this function. You may want to use " + "`presto.latest_sub_partition`" + ) + + column_names = indexes[0]["column_names"] + part_fields = [(column_name, True) for column_name in column_names] + sql = cls._partition_query(table_name, database, 1, part_fields) + df = database.get_df(sql, schema) + return column_names, cls._latest_partition_from_df(df) + + @classmethod + def latest_sub_partition( + cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any + ) -> Any: + """Returns the latest (max) partition value for a table + + A filtering criteria should be passed for all fields that are + partitioned except for the field to be returned. For example, + if a table is partitioned by (``ds``, ``event_type`` and + ``event_category``) and you want the latest ``ds``, you'll want + to provide a filter as keyword arguments for both + ``event_type`` and ``event_category`` as in + ``latest_sub_partition('my_table', + event_category='page', event_type='click')`` + + :param table_name: the name of the table, can be just the table + name or a fully qualified table name as ``schema_name.table_name`` + :type table_name: str + :param schema: schema / database / namespace + :type schema: str + :param database: database query will be run against + :type database: models.Database + + :param kwargs: keyword arguments define the filtering criteria + on the partition list. There can be many of these. + :type kwargs: str + >>> latest_sub_partition('sub_partition_table', event_type='click') + '2018-01-01' + """ + indexes = database.get_indexes(table_name, schema) + part_fields = indexes[0]["column_names"] + for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary + if k not in k in part_fields: # pylint: disable=comparison-with-itself + msg = f"Field [{k}] is not part of the portioning key" + raise SupersetTemplateException(msg) + if len(kwargs.keys()) != len(part_fields) - 1: + msg = ( + "A filter needs to be specified for {} out of the " "{} fields." + ).format(len(part_fields) - 1, len(part_fields)) + raise SupersetTemplateException(msg) + + for field in part_fields: + if field not in kwargs.keys(): + field_to_return = field + + sql = cls._partition_query( + table_name, database, 1, [(field_to_return, True)], kwargs + ) + df = database.get_df(sql, schema) + if df.empty: + return "" + return df.to_dict()[field_to_return][0] + class PrestoEngineSpec(PrestoBaseEngineSpec): engine = "presto" @@ -958,21 +1155,24 @@ def extra_table_metadata( indexes = database.get_indexes(table_name, schema_name) if indexes: - cols = indexes[0].get("column_names", []) - full_table_name = table_name - if schema_name and "." not in table_name: - full_table_name = "{}.{}".format(schema_name, table_name) - pql = cls._partition_query(full_table_name, database) col_names, latest_parts = cls.latest_partition( table_name, schema_name, database, show_first=True ) if not latest_parts: latest_parts = tuple([None] * len(col_names)) + metadata["partitions"] = { - "cols": cols, + "cols": sorted(indexes[0].get("column_names", [])), "latest": dict(zip(col_names, latest_parts)), - "partitionQuery": pql, + "partitionQuery": cls._partition_query( + table_name=( + f"{schema_name}.{table_name}" + if schema_name and "." not in table_name + else table_name + ), + database=database, + ), } # flake8 is not matching `Optional[str]` to `Any` for some reason... @@ -1085,203 +1285,6 @@ def _extract_error_message(cls, ex: Exception) -> str: return error_dict.get("message", _("Unknown Presto Error")) return utils.error_msg_from_exception(ex) - @classmethod - def _partition_query( # pylint: disable=too-many-arguments,too-many-locals - cls, - table_name: str, - database: Database, - limit: int = 0, - order_by: Optional[List[Tuple[str, bool]]] = None, - filters: Optional[Dict[Any, Any]] = None, - ) -> str: - """Returns a partition query - - :param table_name: the name of the table to get partitions from - :type table_name: str - :param limit: the number of partitions to be returned - :type limit: int - :param order_by: a list of tuples of field name and a boolean - that determines if that field should be sorted in descending - order - :type order_by: list of (str, bool) tuples - :param filters: dict of field name and filter value combinations - """ - limit_clause = "LIMIT {}".format(limit) if limit else "" - order_by_clause = "" - if order_by: - l = [] - for field, desc in order_by: - l.append(field + " DESC" if desc else "") - order_by_clause = "ORDER BY " + ", ".join(l) - - where_clause = "" - if filters: - l = [] - for field, value in filters.items(): - l.append(f"{field} = '{value}'") - where_clause = "WHERE " + " AND ".join(l) - - presto_version = database.get_extra().get("version") - - # Partition select syntax changed in v0.199, so check here. - # Default to the new syntax if version is unset. - partition_select_clause = ( - f'SELECT * FROM "{table_name}$partitions"' - if not presto_version - or StrictVersion(presto_version) >= StrictVersion("0.199") - else f"SHOW PARTITIONS FROM {table_name}" - ) - - sql = dedent( - f"""\ - {partition_select_clause} - {where_clause} - {order_by_clause} - {limit_clause} - """ - ) - return sql - - @classmethod - def where_latest_partition( # pylint: disable=too-many-arguments - cls, - table_name: str, - schema: Optional[str], - database: Database, - query: Select, - columns: Optional[List[Dict[str, str]]] = None, - ) -> Optional[Select]: - try: - col_names, values = cls.latest_partition( - table_name, schema, database, show_first=True - ) - except Exception: # pylint: disable=broad-except - # table is not partitioned - return None - - if values is None: - return None - - column_type_by_name = { - column.get("name"): column.get("type") for column in columns or [] - } - - for col_name, value in zip(col_names, values): - if col_name in column_type_by_name: - if column_type_by_name.get(col_name) == "TIMESTAMP": - query = query.where(Column(col_name, TimeStamp()) == value) - elif column_type_by_name.get(col_name) == "DATE": - query = query.where(Column(col_name, Date()) == value) - else: - query = query.where(Column(col_name) == value) - return query - - @classmethod - def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: - if not df.empty: - return df.to_records(index=False)[0].item() - return None - - @classmethod - @cache_manager.data_cache.memoize(timeout=60) - def latest_partition( - cls, - table_name: str, - schema: Optional[str], - database: Database, - show_first: bool = False, - ) -> Tuple[List[str], Optional[List[str]]]: - """Returns col name and the latest (max) partition value for a table - - :param table_name: the name of the table - :param schema: schema / database / namespace - :param database: database query will be run against - :type database: models.Database - :param show_first: displays the value for the first partitioning key - if there are many partitioning keys - :type show_first: bool - - >>> latest_partition('foo_table') - (['ds'], ('2018-01-01',)) - """ - indexes = database.get_indexes(table_name, schema) - if not indexes: - raise SupersetTemplateException( - f"Error getting partition for {schema}.{table_name}. " - "Verify that this table has a partition." - ) - - if len(indexes[0]["column_names"]) < 1: - raise SupersetTemplateException( - "The table should have one partitioned field" - ) - - if not show_first and len(indexes[0]["column_names"]) > 1: - raise SupersetTemplateException( - "The table should have a single partitioned field " - "to use this function. You may want to use " - "`presto.latest_sub_partition`" - ) - - column_names = indexes[0]["column_names"] - part_fields = [(column_name, True) for column_name in column_names] - sql = cls._partition_query(table_name, database, 1, part_fields) - df = database.get_df(sql, schema) - return column_names, cls._latest_partition_from_df(df) - - @classmethod - def latest_sub_partition( - cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any - ) -> Any: - """Returns the latest (max) partition value for a table - - A filtering criteria should be passed for all fields that are - partitioned except for the field to be returned. For example, - if a table is partitioned by (``ds``, ``event_type`` and - ``event_category``) and you want the latest ``ds``, you'll want - to provide a filter as keyword arguments for both - ``event_type`` and ``event_category`` as in - ``latest_sub_partition('my_table', - event_category='page', event_type='click')`` - - :param table_name: the name of the table, can be just the table - name or a fully qualified table name as ``schema_name.table_name`` - :type table_name: str - :param schema: schema / database / namespace - :type schema: str - :param database: database query will be run against - :type database: models.Database - - :param kwargs: keyword arguments define the filtering criteria - on the partition list. There can be many of these. - :type kwargs: str - >>> latest_sub_partition('sub_partition_table', event_type='click') - '2018-01-01' - """ - indexes = database.get_indexes(table_name, schema) - part_fields = indexes[0]["column_names"] - for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary - if k not in k in part_fields: # pylint: disable=comparison-with-itself - msg = f"Field [{k}] is not part of the portioning key" - raise SupersetTemplateException(msg) - if len(kwargs.keys()) != len(part_fields) - 1: - msg = ( - "A filter needs to be specified for {} out of the " "{} fields." - ).format(len(part_fields) - 1, len(part_fields)) - raise SupersetTemplateException(msg) - - for field in part_fields: - if field not in kwargs.keys(): - field_to_return = field - - sql = cls._partition_query( - table_name, database, 1, [(field_to_return, True)], kwargs - ) - df = database.get_df(sql, schema) - if df.empty: - return "" - return df.to_dict()[field_to_return][0] - @classmethod def get_column_spec( cls, diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index c6faf6db6c9b9..2a1d8cc63984b 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -83,11 +83,34 @@ def extra_table_metadata( indexes = database.get_indexes(table_name, schema_name) if indexes: - partitions_columns = [] - for index in indexes: - if index.get("name") == "partition": - partitions_columns += index.get("column_names", []) - metadata["partitions"] = {"cols": partitions_columns} + col_names, latest_parts = cls.latest_partition( + table_name, schema_name, database, show_first=True + ) + + if not latest_parts: + latest_parts = tuple([None] * len(col_names)) + + metadata["partitions"] = { + "cols": sorted( + list( + set( + column_name + for index in indexes + if index.get("name") == "partition" + for column_name in index.get("column_names", []) + ) + ) + ), + "latest": dict(zip(col_names, latest_parts)), + "partitionQuery": cls._partition_query( + table_name=( + f"{schema_name}.{table_name}" + if schema_name and "." not in table_name + else table_name + ), + database=database, + ), + } if database.has_view_by_name(table_name, schema_name): metadata["view"] = database.inspector.get_view_definition( diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 4a76d59a46faf..eef3bb8d3625e 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -492,7 +492,8 @@ def test_presto_extra_table_metadata(self): db.get_df = mock.Mock(return_value=df) PrestoEngineSpec.get_create_view = mock.Mock(return_value=None) result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema") - self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"]) + assert result["partitions"]["cols"] == ["ds", "hour"] + assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1} def test_presto_where_latest_partition(self): db = mock.Mock() diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py index 41a4f4e0f38da..6379d013b2f09 100644 --- a/tests/integration_tests/db_engine_specs/trino_tests.py +++ b/tests/integration_tests/db_engine_specs/trino_tests.py @@ -16,8 +16,10 @@ # under the License. import json from typing import Any, Dict +from unittest import mock from unittest.mock import Mock, patch +import pandas as pd import pytest from sqlalchemy import types @@ -196,3 +198,17 @@ def test_convert_dttm(self): TrinoEngineSpec.convert_dttm("DATE", dttm), "DATE '2019-01-02'", ) + + def test_extra_table_metadata(self): + db = mock.Mock() + db.get_indexes = mock.Mock( + return_value=[{"column_names": ["ds", "hour"], "name": "partition"}] + ) + db.get_extra = mock.Mock(return_value={}) + db.has_view_by_name = mock.Mock(return_value=None) + db.get_df = mock.Mock( + return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}) + ) + result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema") + assert result["partitions"]["cols"] == ["ds", "hour"] + assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}