diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 0bd94885a2479..95e205494c564 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -282,7 +282,9 @@ def data(self) -> Dict[str, Any]: "select_star": self.select_star, } - def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]: + def data_for_slices( # pylint: disable=too-many-locals + self, slices: List[Slice] + ) -> Dict[str, Any]: """ The representation of the datasource containing only the required data to render the provided slices. @@ -317,11 +319,23 @@ def data_for_slices(self, slices: List[Slice]) -> Dict[str, Any]: if "column" in filter_config ) - column_names.update( - column - for column_param in COLUMN_FORM_DATA_PARAMS - for column in utils.get_iterable(form_data.get(column_param) or []) - ) + # legacy charts don't have query_context charts + query_context = slc.get_query_context() + if query_context: + column_names.update( + [ + column + for query in query_context.queries + for column in query.columns + ] + or [] + ) + else: + column_names.update( + column + for column_param in COLUMN_FORM_DATA_PARAMS + for column in utils.get_iterable(form_data.get(column_param) or []) + ) filtered_metrics = [ metric @@ -639,7 +653,6 @@ def data(self) -> Dict[str, Any]: class BaseMetric(AuditMixinNullable, ImportExportMixin): - """Interface for Metrics""" __tablename__: Optional[str] = None # {connector_name}_metric diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 4a4da1cc74917..ef964e298be17 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -184,6 +184,13 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[ "markup_type": "markdown", } + default_query_context = { + "result_format": "json", + "result_type": "full", + "datasource": {"id": tbl.id, "type": "table",}, + "queries": [{"columns": [], "metrics": [],},], + } + admin = get_admin_user() if admin_owner: slice_props = dict( @@ -362,6 +369,22 @@ def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[ metrics=metrics, ), ), + Slice( + **slice_props, + slice_name="Pivot Table v2", + viz_type="pivot_table_v2", + params=get_slice_json( + defaults, + viz_type="pivot_table_v2", + groupbyRows=["name"], + groupbyColumns=["state"], + metrics=[metric], + ), + query_context=get_slice_json( + default_query_context, + queries=[{"columns": ["name", "state"], "metrics": [metric],}], + ), + ), ] misc_slices = [ Slice( diff --git a/superset/models/slice.py b/superset/models/slice.py index 6bf05ffc87fdd..f4d71953f24de 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -40,6 +40,7 @@ from superset.viz import BaseViz, viz_types if TYPE_CHECKING: + from superset.common.query_context import QueryContext from superset.connectors.base.models import BaseDatasource metadata = Model.metadata # pylint: disable=no-member @@ -247,6 +248,18 @@ def form_data(self) -> Dict[str, Any]: update_time_range(form_data) return form_data + def get_query_context(self) -> Optional["QueryContext"]: + # pylint: disable=import-outside-toplevel + from superset.common.query_context import QueryContext + + if self.query_context: + try: + return QueryContext(**json.loads(self.query_context)) + except json.decoder.JSONDecodeError as ex: + logger.error("Malformed json in slice's query context", exc_info=True) + logger.exception(ex) + return None + def get_explore_url( self, base_url: str = "/superset/explore", diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 4c2eb02d92594..f0c685ba4c447 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -790,7 +790,7 @@ def test_get_charts(self): rv = self.get_assert_metric(uri, "get_list") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(data["count"], 33) + self.assertEqual(data["count"], 34) def test_get_charts_changed_on(self): """ @@ -1040,7 +1040,7 @@ def test_get_charts_page(self): """ Chart API: Test get charts filter """ - # Assuming we have 33 sample charts + # Assuming we have 34 sample charts self.login(username="admin") arguments = {"page_size": 10, "page": 0} uri = f"api/v1/chart/?q={prison.dumps(arguments)}" @@ -1054,7 +1054,7 @@ def test_get_charts_page(self): rv = self.get_assert_metric(uri, "get_list") self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) - self.assertEqual(len(data["result"]), 3) + self.assertEqual(len(data["result"]), 4) def test_get_charts_no_data_access(self): """ diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index f989a6d88ceae..e13559e181bd7 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -1099,7 +1099,7 @@ def test_get_database_related_objects(self): rv = self.get_assert_metric(uri, "related_objects") self.assertEqual(rv.status_code, 200) response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["charts"]["count"], 33) + self.assertEqual(response["charts"]["count"], 34) self.assertEqual(response["dashboards"]["count"], 3) def test_get_database_related_objects_not_found(self): diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 56956c31abd01..bc6349c8a2742 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -518,7 +518,7 @@ def test_query_with_non_existent_metrics(self): self.assertTrue("Metric 'invalid' does not exist", context.exception) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_data_for_slices(self): + def test_data_for_slices_with_no_query_context(self): tbl = self.get_table(name="birth_names") slc = ( metadata_db.session.query(Slice) @@ -532,9 +532,35 @@ def test_data_for_slices(self): assert len(data_for_slices["columns"]) == 1 assert data_for_slices["metrics"][0]["metric_name"] == "sum__num" assert data_for_slices["columns"][0]["column_name"] == "gender" - assert set(data_for_slices["verbose_map"].keys()) == set( - ["__timestamp", "sum__num", "gender",] + assert set(data_for_slices["verbose_map"].keys()) == { + "__timestamp", + "sum__num", + "gender", + } + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_data_for_slices_with_query_context(self): + tbl = self.get_table(name="birth_names") + slc = ( + metadata_db.session.query(Slice) + .filter_by( + datasource_id=tbl.id, + datasource_type=tbl.type, + slice_name="Pivot Table v2", + ) + .first() ) + data_for_slices = tbl.data_for_slices([slc]) + assert len(data_for_slices["metrics"]) == 1 + assert len(data_for_slices["columns"]) == 2 + assert data_for_slices["metrics"][0]["metric_name"] == "sum__num" + assert data_for_slices["columns"][0]["column_name"] == "name" + assert set(data_for_slices["verbose_map"].keys()) == { + "__timestamp", + "sum__num", + "name", + "state", + } def test_literal_dttm_type_factory():