Skip to content

Commit

Permalink
Merge pull request #69 from lyft/merge_apache_20190918a
Browse files Browse the repository at this point in the history
Merge apache 20190918a
  • Loading branch information
Beto Dealmeida authored Sep 18, 2019
2 parents f15343c + 06ee3e7 commit 9678836
Show file tree
Hide file tree
Showing 14 changed files with 120 additions and 22 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,4 @@ the world know they are using Superset. Join our growing community!
1. [Zaihang](http://www.zaih.com/)
1. [Zalando](https://www.zalando.com)
1. [Zalora](https://www.zalora.com)
1. [TME QQMUSIC/WESING](https://www.tencentmusic.com/)
3 changes: 3 additions & 0 deletions superset/assets/src/components/TableSelector.css
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@
border-bottom: 1px solid #f2f2f2;
margin: 15px 0;
}
/* Keep table/view option labels on a single line in the table selector
   dropdown so long names don't wrap and misalign the list. */
.TableLabel {
white-space: nowrap;
}
2 changes: 1 addition & 1 deletion superset/assets/src/components/TableSelector.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ export default class TableSelector extends React.PureComponent {
onMouseEnter={() => focusOption(option)}
style={style}
>
<span>
<span className="TableLabel">
<span className="m-r-5">
<small className="text-muted">
<i className={`fa fa-${option.type === 'view' ? 'eye' : 'table'}`} />
Expand Down
8 changes: 7 additions & 1 deletion superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
try:
from pydruid.client import PyDruid
from pydruid.utils.aggregators import count
from pydruid.utils.dimensions import MapLookupExtraction, RegexExtraction
from pydruid.utils.dimensions import (
MapLookupExtraction,
RegexExtraction,
RegisteredLookupExtraction,
)
from pydruid.utils.filters import Dimension, Filter
from pydruid.utils.having import Aggregation
from pydruid.utils.postaggregator import (
Expand Down Expand Up @@ -1402,6 +1406,8 @@ def _create_extraction_fn(dim_spec):
)
elif ext_type == "regex":
extraction_fn = RegexExtraction(fn["expr"])
elif ext_type == "registeredLookup":
extraction_fn = RegisteredLookupExtraction(fn.get("lookup"))
else:
raise Exception(_("Unsupported extraction function: " + ext_type))
return (col, extraction_fn)
Expand Down
2 changes: 1 addition & 1 deletion superset/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
# need to do this because we can not specify a mixed dtype when
# instantiating the DataFrame, and this allows us to have different
# dtypes for each column.
array = np.array(data)
array = np.array(data, dtype="object")
data = {
column: pd.Series(array[:, i], dtype=dtype[column])
for i, column in enumerate(column_names)
Expand Down
14 changes: 11 additions & 3 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import hashlib
import os
import re
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union

from flask import g
from flask_babel import lazy_gettext as _
Expand All @@ -40,6 +40,10 @@
from superset import app, db, sql_parse
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database


class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
Expand Down Expand Up @@ -538,7 +542,9 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]:
return sorted(inspector.get_schema_names())

@classmethod
def get_table_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
"""
Get all tables from schema
Expand All @@ -552,7 +558,9 @@ def get_table_names(cls, inspector: Inspector, schema: Optional[str]) -> List[st
return sorted(tables)

@classmethod
def get_view_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
def get_view_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
"""
Get all views from schema
Expand Down
8 changes: 6 additions & 2 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
# under the License.
# pylint: disable=C,R,W
from datetime import datetime
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.dialects.postgresql.base import PGInspector

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database


class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
Expand Down Expand Up @@ -64,7 +68,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):

@classmethod
def get_table_names(
cls, inspector: PGInspector, schema: Optional[str]
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
) -> List[str]:
"""Need to consider foreign tables for PostgreSQL"""
tables = inspector.get_table_names(schema)
Expand Down
44 changes: 39 additions & 5 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import re
import textwrap
import time
from typing import Any, cast, Dict, List, Optional, Set, Tuple
from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
from urllib import parse

import simplejson as json
Expand All @@ -40,6 +40,10 @@
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database

QueryStatus = utils.QueryStatus
config = app.config

Expand All @@ -53,8 +57,8 @@
"real": "float64",
"double": "float64",
"varchar": "object",
"timestamp": "datetime64",
"date": "datetime64",
"timestamp": "datetime64[ns]",
"date": "datetime64[ns]",
"varbinary": "object",
}

Expand Down Expand Up @@ -128,14 +132,44 @@ def get_allow_cost_estimate(cls, version: str = None) -> bool:
return version is not None and StrictVersion(version) >= StrictVersion("0.319")

@classmethod
def get_table_names(
    cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
    """Return the table names in ``schema``, optionally excluding views.

    When the PRESTO_SPLIT_VIEWS_FROM_TABLES feature flag is enabled, the
    names reported as views are subtracted from the table list so each
    object is listed exactly once (PyHive's Presto dialect mixes tables
    and views together in ``get_table_names``).
    """
    tables = super().get_table_names(database, inspector, schema)
    if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
        return tables

    views = set(cls.get_view_names(database, inspector, schema))
    actual_tables = set(tables) - views
    return list(actual_tables)

@classmethod
def get_view_names(
    cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
    """Return the view names in ``schema`` (all schemas when ``None``).

    PyHive's sqlalchemy_presto dialect does not implement
    ``get_view_names``, so views are fetched directly from Presto's
    ``information_schema.views``. Returns an empty list unless the
    PRESTO_SPLIT_VIEWS_FROM_TABLES feature flag is enabled.
    """
    if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
        return []

    if schema:
        # Parameterized query — schema comes from user input.
        sql = (
            "SELECT table_name FROM information_schema.views "
            "WHERE table_schema=%(schema)s"
        )
        params = {"schema": schema}
    else:
        sql = "SELECT table_name FROM information_schema.views"
        params = {}

    engine = cls.get_engine(database, schema=schema)
    # closing() releases the raw connection and cursor even if the
    # query raises.
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute(sql, params)
            results = cursor.fetchall()

    return [row[0] for row in results]

@classmethod
def _create_column_info(cls, name: str, data_type: str) -> dict:
Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
# under the License.
# pylint: disable=C,R,W
from datetime import datetime
from typing import List
from typing import List, TYPE_CHECKING

from sqlalchemy.engine.reflection import Inspector

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database


class SqliteEngineSpec(BaseEngineSpec):
engine = "sqlite"
Expand Down Expand Up @@ -79,6 +83,8 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> str:
return "'{}'".format(iso)

@classmethod
def get_table_names(
    cls, database: "Database", inspector: Inspector, schema: str
) -> List[str]:
    """Need to disregard the schema for Sqlite"""
    # SQLite has a single implicit schema, so the schema argument is
    # intentionally ignored; sort for a stable ordering in the UI.
    return sorted(inspector.get_table_names())
4 changes: 2 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ def get_all_table_names_in_schema(
"""
try:
tables = self.db_engine_spec.get_table_names(
inspector=self.inspector, schema=schema
database=self, inspector=self.inspector, schema=schema
)
return [
utils.DatasourceName(table=table, schema=schema) for table in tables
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def get_all_view_names_in_schema(
"""
try:
views = self.db_engine_spec.get_view_names(
inspector=self.inspector, schema=schema
database=self, inspector=self.inspector, schema=schema
)
return [utils.DatasourceName(table=view, schema=schema) for view in views]
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,7 @@ def get_datasource_label(ds_name: utils.DatasourceName) -> str:
for vn in views[:max_views]
]
)
table_options.sort(key=lambda value: value["label"])
payload = {"tableLength": len(tables) + len(views), "options": table_options}
return json_success(json.dumps(payload))

Expand Down
6 changes: 6 additions & 0 deletions tests/dataframe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,9 @@ def test_int64_with_missing_data(self):
cdf.raw_df.values.tolist(),
[[np.nan], [1239162456494753670], [np.nan], [np.nan], [np.nan], [np.nan]],
)

def test_pandas_datetime64(self):
    """A NULL-only Presto timestamp column should map to datetime64[ns]."""
    rows = [(None,)]
    description = [("ds", "timestamp", None, None, None, None, True)]
    frame = SupersetDataFrame(rows, description, PrestoEngineSpec)
    self.assertEqual(frame.raw_df.dtypes[0], np.dtype("<M8[ns]"))
12 changes: 8 additions & 4 deletions tests/db_engine_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ def test_engine_time_grain_validity(self):
self.assertSetEqual(defined_grains, intersection, engine)

def test_presto_get_view_names_return_empty_list(self):
    # get_view_names() now takes (database, inspector, schema); with the
    # PRESTO_SPLIT_VIEWS_FROM_TABLES flag off it must return an empty list.
    # assertEqual: assertEquals is a deprecated alias.
    self.assertEqual(
        [], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
    )

def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
Expand Down Expand Up @@ -877,7 +879,9 @@ def test_presto_where_latest_partition(self):
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)

def test_hive_get_view_names_return_empty_list(self):
    # get_view_names() now takes (database, inspector, schema); Hive's
    # implementation always returns an empty list.
    # assertEqual: assertEquals is a deprecated alias.
    self.assertEqual(
        [], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
    )

def test_bigquery_sqla_column_label(self):
label = BigQueryEngineSpec.make_label_compatible(column("Col").name)
Expand Down Expand Up @@ -952,15 +956,15 @@ def test_get_table_names(self):
ie. when try_remove_schema_from_table_name == True. """
base_result_expected = ["table", "table_2"]
base_result = BaseEngineSpec.get_table_names(
schema="schema", inspector=inspector
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(base_result_expected, base_result)

""" Make sure postgres doesn't try to remove schema name from table name
ie. when try_remove_schema_from_table_name == False. """
pg_result_expected = ["schema.table", "table_2", "table_3"]
pg_result = PostgresEngineSpec.get_table_names(
schema="schema", inspector=inspector
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(pg_result_expected, pg_result)

Expand Down
27 changes: 26 additions & 1 deletion tests/druid_func_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from unittest.mock import Mock

try:
from pydruid.utils.dimensions import MapLookupExtraction, RegexExtraction
from pydruid.utils.dimensions import (
MapLookupExtraction,
RegexExtraction,
RegisteredLookupExtraction,
)
import pydruid.utils.postaggregator as postaggs
except ImportError:
pass
Expand Down Expand Up @@ -110,6 +114,27 @@ def test_get_filters_extraction_fn_regex(self):
f_ext_fn = f.extraction_function
self.assertEqual(dim_ext_fn["expr"], f_ext_fn._expr)

@unittest.skipUnless(
    SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
)
def test_get_filters_extraction_fn_registered_lookup_extraction(self):
    """A filter on a registeredLookup-extracted dimension should build a
    RegisteredLookupExtraction carrying the configured lookup name."""
    filters = [{"col": "country", "val": ["Spain"], "op": "in"}]
    dimension_spec = {
        "type": "extraction",
        "dimension": "country_name",
        "outputName": "country",
        "outputType": "STRING",
        "extractionFn": {"type": "registeredLookup", "lookup": "country_name"},
    }
    spec_json = json.dumps(dimension_spec)
    col = DruidColumn(column_name="country", dimension_spec_json=spec_json)
    column_dict = {"country": col}
    f = DruidDatasource.get_filters(filters, [], column_dict)
    # Use unittest assertions throughout: a bare `assert` is stripped
    # under `python -O` and gives a poorer failure message.
    self.assertIsInstance(f.extraction_function, RegisteredLookupExtraction)
    dim_ext_fn = dimension_spec["extractionFn"]
    self.assertEqual(dim_ext_fn["type"], f.extraction_function.extraction_type)
    self.assertEqual(dim_ext_fn["lookup"], f.extraction_function._lookup)

@unittest.skipUnless(
SupersetTestCase.is_module_installed("pydruid"), "pydruid not installed"
)
Expand Down

0 comments on commit 9678836

Please sign in to comment.