From 0cd4ef54f884e77cd4c5a810f10697b33e9f6e56 Mon Sep 17 00:00:00 2001
From: 1AB9502
Date: Wed, 18 Sep 2019 00:45:47 -0600
Subject: [PATCH 1/5] Add RegisteredLookupExtraction support to extraction function (#8185)

* Add RegisteredLookupExtraction support to extraction function

* Fix formatting issues

* Reformat druid_func_tests through black
---
 superset/connectors/druid/models.py |  8 +++++++-
 tests/druid_func_tests.py           | 27 ++++++++++++++++++++++++++-
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index f204a90c9b18c..b908056aaed62 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -35,7 +35,11 @@
 try:
     from pydruid.client import PyDruid
     from pydruid.utils.aggregators import count
-    from pydruid.utils.dimensions import MapLookupExtraction, RegexExtraction
+    from pydruid.utils.dimensions import (
+        MapLookupExtraction,
+        RegexExtraction,
+        RegisteredLookupExtraction,
+    )
     from pydruid.utils.filters import Dimension, Filter
     from pydruid.utils.having import Aggregation
     from pydruid.utils.postaggregator import (
@@ -1402,6 +1406,8 @@ def _create_extraction_fn(dim_spec):
                 )
             elif ext_type == "regex":
                 extraction_fn = RegexExtraction(fn["expr"])
+            elif ext_type == "registeredLookup":
+                extraction_fn = RegisteredLookupExtraction(fn.get("lookup"))
             else:
                 raise Exception(_("Unsupported extraction function: " + ext_type))
         return (col, extraction_fn)
diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py
index c4d6ab92c3010..460595dd178ab 100644
--- a/tests/druid_func_tests.py
+++ b/tests/druid_func_tests.py
@@ -19,7 +19,11 @@
 from unittest.mock import Mock
 
 try:
-    from pydruid.utils.dimensions import MapLookupExtraction, RegexExtraction
+    from pydruid.utils.dimensions import (
+        MapLookupExtraction,
+        RegexExtraction,
+        RegisteredLookupExtraction,
+    )
     import pydruid.utils.postaggregator as postaggs
 except ImportError:
     pass
@@ -110,6 +114,27 @@ def test_get_filters_extraction_fn_regex(self):
         f_ext_fn = f.extraction_function
         self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr)
 
+    @unittest.skipUnless(
+        SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
+    )
+    def test_get_filters_extraction_fn_registered_lookup_extraction(self):
+        filters = [{"col": "country", "val": ["Spain"], "op": "in"}]
+        dimension_spec = {
+            "type": "extraction",
+            "dimension": "country_name",
+            "outputName": "country",
+            "outputType": "STRING",
+            "extractionFn": {"type": "registeredLookup", "lookup": "country_name"},
+        }
+        spec_json = json.dumps(dimension_spec)
+        col = DruidColumn(column_name="country", dimension_spec_json=spec_json)
+        column_dict = {"country": col}
+        f = DruidDatasource.get_filters(filters, [], column_dict)
+        assert isinstance(f.extraction_function, RegisteredLookupExtraction)
+        dim_ext_fn = dimension_spec["extractionFn"]
+        self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type)
+        self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup)
+
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
     )

From 8d04e1f55fc8039b7134806ca15a62f7cbe5c7a5 Mon Sep 17 00:00:00 2001
From: MaiTiano
Date: Wed, 18 Sep 2019 23:15:47 +0800
Subject: [PATCH 2/5] Update README.md (#8246)

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 1d4a45efe97c6..194e7f0f19f8c 100644
--- a/README.md
+++ b/README.md
@@ -212,3 +212,4 @@ the world know they are using Superset. Join our growing community!
 1. [Zaihang](http://www.zaih.com/)
 1. [Zalando](https://www.zalando.com)
 1. [Zalora](https://www.zalora.com)
+1. [TME QQMUSIC/WESING](https://www.tencentmusic.com/)

From 4088a84eb4d83c54cde1fbfc25c2e68fa33c399f Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Wed, 18 Sep 2019 12:46:50 -0700
Subject: [PATCH 3/5] Small fix for Presto dtype map (#8251)

* Small fix for Presto dtype map

* Add unit test
---
 superset/db_engine_specs/presto.py | 4 ++--
 tests/dataframe_test.py            | 6 ++++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index c060da56f4609..6e28bd1e12a2b 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -53,8 +53,8 @@
     "real": "float64",
     "double": "float64",
     "varchar": "object",
-    "timestamp": "datetime64",
-    "date": "datetime64",
+    "timestamp": "datetime64[ns]",
+    "date": "datetime64[ns]",
     "varbinary": "object",
 }
 
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
index 6b421e919c56a..d254d63a040bb 100644
--- a/tests/dataframe_test.py
+++ b/tests/dataframe_test.py
@@ -129,3 +129,9 @@ def test_int64_with_missing_data(self):
             cdf.raw_df.values.tolist(),
             [[np.nan], [1239162456494753670], [np.nan], [np.nan], [np.nan], [np.nan]],
         )
+
+    def test_pandas_datetime64(self):
+        data = [(None,)]
+        cursor_descr = [("ds", "timestamp", None, None, None, None, True)]
+        cdf = SupersetDataFrame(data, cursor_descr, PrestoEngineSpec)
+        self.assertEqual(cdf.raw_df.dtypes[0], np.dtype("<M8[ns]"))

Date: Wed, 18 Sep 2019 12:47:10 -0700
Subject: [PATCH 4/5] Show Presto views as views, not tables (#8243)

* WIP

* Implement views in Presto

* Clean up

* Fix CSS

* Fix unit tests

* Add types to database

* Fix circular import
---
 .../assets/src/components/TableSelector.css |  3 ++
 .../assets/src/components/TableSelector.jsx |  2 +-
 superset/db_engine_specs/base.py            | 14 +++++--
 superset/db_engine_specs/postgres.py        |  8 +++-
 superset/db_engine_specs/presto.py          | 40 +++++++++++++++++--
 superset/db_engine_specs/sqlite.py          | 10 ++++-
 superset/models/core.py                     |  4 +-
 superset/views/core.py                      |  1 +
 tests/db_engine_specs_test.py               | 12 ++++--
 9 files changed, 77 insertions(+), 17 deletions(-)

diff --git a/superset/assets/src/components/TableSelector.css b/superset/assets/src/components/TableSelector.css
index f4098c14056c4..078e57df55681 100644
--- a/superset/assets/src/components/TableSelector.css
+++ b/superset/assets/src/components/TableSelector.css
@@ -36,3 +36,6 @@
   border-bottom: 1px solid #f2f2f2;
   margin: 15px 0;
 }
+.TableLabel {
+  white-space: nowrap;
+}
diff --git a/superset/assets/src/components/TableSelector.jsx b/superset/assets/src/components/TableSelector.jsx
index d0785d61af084..1b83acc32b497 100644
--- a/superset/assets/src/components/TableSelector.jsx
+++ b/superset/assets/src/components/TableSelector.jsx
@@ -221,7 +221,7 @@ export default class TableSelector extends React.PureComponent {
             onMouseEnter={() => focusOption(option)}
             style={style}
           >
-
+
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index b71966ce12f33..bd7b1d1de099d 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -20,7 +20,7 @@
 import hashlib
 import os
 import re
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union
 
 from flask import g
 from flask_babel import lazy_gettext as _
@@ -40,6 +40,10 @@
 from superset import app, db, sql_parse
 from superset.utils import core as utils
 
+if TYPE_CHECKING:
+    # prevent circular imports
+    from superset.models.core import Database
+
 
 class TimeGrain(NamedTuple):
     name: str  # TODO: redundant field, remove
@@ -538,7 +542,9 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]:
         return sorted(inspector.get_schema_names())
 
     @classmethod
-    def get_table_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
+    def get_table_names(
+        cls, database: "Database", inspector: Inspector, schema: Optional[str]
+    ) -> List[str]:
         """
         Get all tables from schema
 
@@ -552,7 +558,9 @@ def get_table_names(cls, inspector: Inspector, schema: Optional[str]) -> List[st
         return sorted(tables)
 
     @classmethod
-    def get_view_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
+    def get_view_names(
+        cls, database: "Database", inspector: Inspector, schema: Optional[str]
+    ) -> List[str]:
         """
         Get all views from schema
 
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index 4716b07586ac8..5b8988021d4a0 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -16,12 +16,16 @@
 # under the License.
 # pylint: disable=C,R,W
 from datetime import datetime
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, TYPE_CHECKING
 
 from sqlalchemy.dialects.postgresql.base import PGInspector
 
 from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
 
+if TYPE_CHECKING:
+    # prevent circular imports
+    from superset.models.core import Database
+
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -64,7 +68,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
 
     @classmethod
     def get_table_names(
-        cls, inspector: PGInspector, schema: Optional[str]
+        cls, database: "Database", inspector: PGInspector, schema: Optional[str]
     ) -> List[str]:
         """Need to consider foreign tables for PostgreSQL"""
         tables = inspector.get_table_names(schema)
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6e28bd1e12a2b..ddc6442ac1937 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -23,7 +23,7 @@
 import re
 import textwrap
 import time
-from typing import Any, cast, Dict, List, Optional, Set, Tuple
+from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
 from urllib import parse
 
 import simplejson as json
@@ -40,6 +40,10 @@
 from superset.sql_parse import ParsedQuery
 from superset.utils import core as utils
 
+if TYPE_CHECKING:
+    # prevent circular imports
+    from superset.models.core import Database
+
 QueryStatus = utils.QueryStatus
 config = app.config
 
@@ -128,14 +132,44 @@ def get_allow_cost_estimate(cls, version: str = None) -> bool:
         return version is not None and StrictVersion(version) >= StrictVersion("0.319")
 
     @classmethod
-    def get_view_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
+    def get_table_names(
+        cls, database: "Database", inspector: Inspector, schema: Optional[str]
+    ) -> List[str]:
+        tables = super().get_table_names(database, inspector, schema)
+        if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
+            return tables
+
+        views = set(cls.get_view_names(database, inspector, schema))
+        actual_tables = set(tables) - views
+        return list(actual_tables)
+
+    @classmethod
+    def get_view_names(
+        cls, database: "Database", inspector: Inspector, schema: Optional[str]
+    ) -> List[str]:
         """Returns an empty list
 
         get_table_names() function returns all table names and view names,
         and get_view_names() is not implemented in sqlalchemy_presto.py
         https://github.com/dropbox/PyHive/blob/e25fc8440a0686bbb7a5db5de7cb1a77bdb4167a/pyhive/sqlalchemy_presto.py
         """
-        return []
+        if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
+            return []
+
+        if schema:
+            sql = "SELECT table_name FROM information_schema.views WHERE table_schema=%(schema)s"
+            params = {"schema": schema}
+        else:
+            sql = "SELECT table_name FROM information_schema.views"
+            params = {}
+
+        engine = cls.get_engine(database, schema=schema)
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                cursor.execute(sql, params)
+                results = cursor.fetchall()
+
+        return [row[0] for row in results]
 
     @classmethod
     def _create_column_info(cls, name: str, data_type: str) -> dict:
diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py
index 28c8843057cfc..ff7074b340975 100644
--- a/superset/db_engine_specs/sqlite.py
+++ b/superset/db_engine_specs/sqlite.py
@@ -16,13 +16,17 @@
 # under the License.
 # pylint: disable=C,R,W
 from datetime import datetime
-from typing import List
+from typing import List, TYPE_CHECKING
 
 from sqlalchemy.engine.reflection import Inspector
 
 from superset.db_engine_specs.base import BaseEngineSpec
 from superset.utils import core as utils
 
+if TYPE_CHECKING:
+    # prevent circular imports
+    from superset.models.core import Database
+
 
 class SqliteEngineSpec(BaseEngineSpec):
     engine = "sqlite"
@@ -79,6 +83,8 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> str:
         return "'{}'".format(iso)
 
     @classmethod
-    def get_table_names(cls, inspector: Inspector, schema: str) -> List[str]:
+    def get_table_names(
+        cls, database: "Database", inspector: Inspector, schema: str
+    ) -> List[str]:
         """Need to disregard the schema for Sqlite"""
         return sorted(inspector.get_table_names())
diff --git a/superset/models/core.py b/superset/models/core.py
index b31fb6710c5dd..200af496431b7 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -1063,7 +1063,7 @@ def get_all_table_names_in_schema(
         """
         try:
             tables = self.db_engine_spec.get_table_names(
-                inspector=self.inspector, schema=schema
+                database=self, inspector=self.inspector, schema=schema
             )
             return [
                 utils.DatasourceName(table=table, schema=schema) for table in tables
             ]
@@ -1097,7 +1097,7 @@ def get_all_view_names_in_schema(
         """
         try:
             views = self.db_engine_spec.get_view_names(
-                inspector=self.inspector, schema=schema
+                database=self, inspector=self.inspector, schema=schema
             )
             return [utils.DatasourceName(table=view, schema=schema) for view in views]
         except Exception as e:
diff --git a/superset/views/core.py b/superset/views/core.py
index 36fb3ca7727fc..ad6c3babf57e6 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1572,6 +1572,7 @@ def get_datasource_label(ds_name: utils.DatasourceName) -> str:
                 for vn in views[:max_views]
             ]
         )
+        table_options.sort(key=lambda value: value["label"])
         payload = {"tableLength": len(tables) + len(views), "options": table_options}
         return json_success(json.dumps(payload))
 
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index f6354c0eeac0d..ec6ff24f228aa 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -342,7 +342,9 @@ def test_engine_time_grain_validity(self):
             self.assertSetEqual(defined_grains, intersection, engine)
 
     def test_presto_get_view_names_return_empty_list(self):
-        self.assertEquals([], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))
+        self.assertEquals(
+            [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
+        )
 
     def verify_presto_column(self, column, expected_results):
         inspector = mock.Mock()
@@ -877,7 +879,9 @@ def test_presto_where_latest_partition(self):
         self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)
 
     def test_hive_get_view_names_return_empty_list(self):
-        self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))
+        self.assertEquals(
+            [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
+        )
 
     def test_bigquery_sqla_column_label(self):
         label = BigQueryEngineSpec.make_label_compatible(column("Col").name)
@@ -952,7 +956,7 @@ def test_get_table_names(self):
         ie. when try_remove_schema_from_table_name == True. """
         base_result_expected = ["table", "table_2"]
         base_result = BaseEngineSpec.get_table_names(
-            schema="schema", inspector=inspector
+            database=mock.ANY, schema="schema", inspector=inspector
         )
         self.assertListEqual(base_result_expected, base_result)
 
@@ -960,7 +964,7 @@ def test_get_table_names(self):
         """
         ie. when try_remove_schema_from_table_name == False. """
         pg_result_expected = ["schema.table", "table_2", "table_3"]
         pg_result = PostgresEngineSpec.get_table_names(
-            schema="schema", inspector=inspector
+            database=mock.ANY, schema="schema", inspector=inspector
         )
         self.assertListEqual(pg_result_expected, pg_result)

From 8e1fc2b0ba665c35314eadf9a5196cd35bb2a433 Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Wed, 18 Sep 2019 13:32:58 -0700
Subject: [PATCH 5/5] Fix array casting (#8253)

---
 superset/dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/superset/dataframe.py b/superset/dataframe.py
index 47683b0cc0d44..c1733362d8073 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -110,7 +110,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
         # need to do this because we can not specify a mixed dtype when
         # instantiating the DataFrame, and this allows us to have different
         # dtypes for each column.
-        array = np.array(data)
+        array = np.array(data, dtype="object")
         data = {
             column: pd.Series(array[:, i], dtype=dtype[column])
             for i, column in enumerate(column_names)
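
A quick illustration of the NumPy behaviour the final hunk works around (a minimal sketch with made-up sample rows, not code from the patch): without dtype="object", np.array() collapses a mixed-type result set into a single common dtype, which would defeat the per-column pd.Series(..., dtype=...) casts that follow.

    import numpy as np

    # Two cursor-style rows with mixed types (illustrative values only).
    rows = [(1, "a"), (2, "b")]

    coerced = np.array(rows)                    # NumPy picks one common dtype (a Unicode string type)
    preserved = np.array(rows, dtype="object")  # each cell keeps its original Python type

    print(coerced.dtype, type(coerced[0, 0]))      # e.g. <U21, numpy.str_
    print(preserved.dtype, type(preserved[0, 0]))  # object, int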