Skip to content

Commit

Permalink
fix(dashboard): Return columns and verbose_map for groupby values of …
Browse files Browse the repository at this point in the history
…Pivot Table v2 [ID-7] (#17287)

* fix(dashboard): Return columns and verbose_map for groupby values of Pivot Table v2

* Refactor

* Fix test and lint

* Fix test

* Refactor

* Fix lint
  • Loading branch information
kgabryje authored Nov 5, 2021
1 parent ab1fcf3 commit fa51b32
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 14 deletions.
27 changes: 20 additions & 7 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,9 @@ def data(self) -> Dict[str, Any]:
"select_star": self.select_star,
}

def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]:
def data_for_slices( # pylint: disable=too-many-locals
self, slices: List[Slice]
) -> Dict[str, Any]:
"""
The representation of the datasource containing only the required data
to render the provided slices.
Expand Down Expand Up @@ -317,11 +319,23 @@ def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]:
if "column" in filter_config
)

column_names.update(
column
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
)
# legacy charts don't have query_context charts
query_context = slc.get_query_context()
if query_context:
column_names.update(
[
column
for query in query_context.queries
for column in query.columns
]
or []
)
else:
column_names.update(
column
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
)

filtered_metrics = [
metric
Expand Down Expand Up @@ -639,7 +653,6 @@ def data(self) -> Dict[str, Any]:


class BaseMetric(AuditMixinNullable, ImportExportMixin):

"""Interface for Metrics"""

__tablename__: Optional[str] = None # {connector_name}_metric
Expand Down
23 changes: 23 additions & 0 deletions superset/examples/birth_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
"markup_type": "markdown",
}

default_query_context = {
"result_format": "json",
"result_type": "full",
"datasource": {"id": tbl.id, "type": "table",},
"queries": [{"columns": [], "metrics": [],},],
}

admin = get_admin_user()
if admin_owner:
slice_props = dict(
Expand Down Expand Up @@ -362,6 +369,22 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[
metrics=metrics,
),
),
Slice(
**slice_props,
slice_name="Pivot Table v2",
viz_type="pivot_table_v2",
params=get_slice_json(
defaults,
viz_type="pivot_table_v2",
groupbyRows=["name"],
groupbyColumns=["state"],
metrics=[metric],
),
query_context=get_slice_json(
default_query_context,
queries=[{"columns": ["name", "state"], "metrics": [metric],}],
),
),
]
misc_slices = [
Slice(
Expand Down
13 changes: 13 additions & 0 deletions superset/models/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from superset.viz import BaseViz, viz_types

if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource

metadata = Model.metadata # pylint: disable=no-member
Expand Down Expand Up @@ -247,6 +248,18 @@ def form_data(self) -> Dict[str, Any]:
update_time_range(form_data)
return form_data

def get_query_context(self) -> Optional["QueryContext"]:
# pylint: disable=import-outside-toplevel
from superset.common.query_context import QueryContext

This comment has been minimized.

Copy link
@ofekisr

ofekisr Nov 8, 2021

Contributor

you have added a circular dependency so it is not solid


if self.query_context:
try:
return QueryContext(**json.loads(self.query_context))

This comment has been minimized.

Copy link
@ofekisr

ofekisr Nov 8, 2021

Contributor

I don't see any validation regarding the context of the raw query context.
If validation is required (by schema ) on the data form taken from chart data API it should be here too

except json.decoder.JSONDecodeError as ex:
logger.error("Malformed json in slice's query context", exc_info=True)
logger.exception(ex)
return None

def get_explore_url(
self,
base_url: str = "/superset/explore",
Expand Down
6 changes: 3 additions & 3 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def test_get_charts(self):
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 33)
self.assertEqual(data["count"], 34)

def test_get_charts_changed_on(self):
"""
Expand Down Expand Up @@ -1040,7 +1040,7 @@ def test_get_charts_page(self):
"""
Chart API: Test get charts filter
"""
# Assuming we have 33 sample charts
# Assuming we have 34 sample charts
self.login(username="admin")
arguments = {"page_size": 10, "page": 0}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
Expand All @@ -1054,7 +1054,7 @@ def test_get_charts_page(self):
rv = self.get_assert_metric(uri, "get_list")
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 3)
self.assertEqual(len(data["result"]), 4)

def test_get_charts_no_data_access(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,7 @@ def test_get_database_related_objects(self):
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["charts"]["count"], 33)
self.assertEqual(response["charts"]["count"], 34)
self.assertEqual(response["dashboards"]["count"], 3)

def test_get_database_related_objects_not_found(self):
Expand Down
32 changes: 29 additions & 3 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def test_query_with_non_existent_metrics(self):
self.assertTrue("Metric 'invalid' does not exist", context.exception)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices(self):
def test_data_for_slices_with_no_query_context(self):
tbl = self.get_table(name="birth_names")
slc = (
metadata_db.session.query(Slice)
Expand All @@ -532,9 +532,35 @@ def test_data_for_slices(self):
assert len(data_for_slices["columns"]) == 1
assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
assert data_for_slices["columns"][0]["column_name"] == "gender"
assert set(data_for_slices["verbose_map"].keys()) == set(
["__timestamp", "sum__num", "gender",]
assert set(data_for_slices["verbose_map"].keys()) == {
"__timestamp",
"sum__num",
"gender",
}

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_data_for_slices_with_query_context(self):
tbl = self.get_table(name="birth_names")
slc = (
metadata_db.session.query(Slice)
.filter_by(
datasource_id=tbl.id,
datasource_type=tbl.type,
slice_name="Pivot Table v2",
)
.first()
)
data_for_slices = tbl.data_for_slices([slc])
assert len(data_for_slices["metrics"]) == 1
assert len(data_for_slices["columns"]) == 2
assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
assert data_for_slices["columns"][0]["column_name"] == "name"
assert set(data_for_slices["verbose_map"].keys()) == {
"__timestamp",
"sum__num",
"name",
"state",
}


def test_literal_dttm_type_factory():
Expand Down

0 comments on commit fa51b32

Please sign in to comment.