From f1fe952a7b8a78c329da833ab5955c58a10928c5 Mon Sep 17 00:00:00 2001
From: ofekisr
Date: Thu, 11 Nov 2021 09:50:40 +0200
Subject: [PATCH] refactor move ChartDataResult enums to common

---
 superset/charts/api.py | 7 +---
 superset/charts/post_processing.py | 8 +---
 superset/charts/schemas.py | 3 +-
 superset/common/chart_data.py | 40 +++++++++++++++++++
 superset/common/query_actions.py | 2 +-
 superset/common/query_context.py | 3 +-
 superset/common/query_object.py | 2 +-
 superset/reports/commands/execute.py | 2 +-
 superset/utils/core.py | 23 -----------
 superset/views/core.py | 23 ++++++-----
 tests/integration_tests/charts/api_tests.py | 24 +++++------
 .../integration_tests/query_context_tests.py | 9 +----
 12 files changed, 74 insertions(+), 72 deletions(-)
 create mode 100644 superset/common/chart_data.py

diff --git a/superset/charts/api.py b/superset/charts/api.py
index 81e087c4d3d6e..e94ad7c326e40 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -67,17 +67,14 @@
 )
 from superset.commands.importers.exceptions import NoValidFilesFoundError
 from superset.commands.importers.v1.utils import get_contents_from_bundle
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
 from superset.exceptions import QueryObjectValidationError
 from superset.extensions import event_logger, security_manager
 from superset.models.slice import Slice
 from superset.tasks.thumbnails import cache_chart_thumbnail
 from superset.utils.async_query_manager import AsyncQueryTokenException
-from superset.utils.core import (
-    ChartDataResultFormat,
-    ChartDataResultType,
-    json_int_dttm_ser,
-)
+from superset.utils.core import json_int_dttm_ser
 from superset.utils.screenshots import ChartScreenshot
 from superset.utils.urls import get_url_path
 from superset.views.base_api import (
diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py
index 23c25eb0cf622..7a7fdfe95dd5f 100644
--- a/superset/charts/post_processing.py
+++ b/superset/charts/post_processing.py
@@ -31,12 +31,8 @@
 
 import pandas as pd
 
-from superset.utils.core import (
-    ChartDataResultFormat,
-    DTTM_ALIAS,
-    extract_dataframe_dtypes,
-    get_metric_name,
-)
+from superset.common.chart_data import ChartDataResultFormat
+from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
 
 
 def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index af9bc8d62f99c..509462d9269fb 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -23,13 +23,12 @@
 from marshmallow_enum import EnumField
 
 from superset import app
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.query_context import QueryContext
 from superset.db_engine_specs.base import builtin_time_grains
 from superset.utils import schema as utils
 from superset.utils.core import (
     AnnotationType,
-    ChartDataResultFormat,
-    ChartDataResultType,
     FilterOperator,
     PostProcessingBoxplotWhiskerType,
     PostProcessingContributionOrientation,
diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py
new file mode 100644
index 0000000000000..f3917d6d87177
--- /dev/null
+++ b/superset/common/chart_data.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from enum import Enum
+
+
+class ChartDataResultFormat(str, Enum):
+    """
+    Chart data response format
+    """
+
+    CSV = "csv"
+    JSON = "json"
+
+
+class ChartDataResultType(str, Enum):
+    """
+    Chart data response type
+    """
+
+    COLUMNS = "columns"
+    FULL = "full"
+    QUERY = "query"
+    RESULTS = "results"
+    SAMPLES = "samples"
+    TIMEGRAINS = "timegrains"
+    POST_PROCESSED = "post_processed"
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index 925d4a19516ce..2e21c260671e4 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -20,11 +20,11 @@
 from flask_babel import _
 
 from superset import app
+from superset.common.chart_data import ChartDataResultType
 from superset.common.db_query_status import QueryStatus
 from superset.connectors.base.models import BaseDatasource
 from superset.exceptions import QueryObjectValidationError
 from superset.utils.core import (
-    ChartDataResultType,
     extract_column_dtype,
     extract_dataframe_dtypes,
     ExtraFiltersReasonType,
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index eee2bbee42531..d545e0cf241a3 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -29,6 +29,7 @@
 from superset import app, db, is_feature_enabled
 from superset.annotation_layers.dao import AnnotationLayerDAO
 from superset.charts.dao import ChartDAO
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.db_query_status import QueryStatus
 from superset.common.query_actions import get_query_results
 from superset.common.query_object import QueryObject
@@ -42,8 +43,6 @@
 from superset.utils import csv
 from superset.utils.cache import generate_cache_key, set_and_log_cache
 from superset.utils.core import (
-    ChartDataResultFormat,
-    ChartDataResultType,
     DatasourceDict,
     DTTM_ALIAS,
     error_msg_from_exception,
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 31f7d274e7733..44f6a0425f914 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -23,6 +23,7 @@
 from pandas import DataFrame
 
 from superset import app, db
+from superset.common.chart_data import ChartDataResultType
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.exceptions import QueryObjectValidationError
@@ -30,7 +31,6 @@
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
     apply_max_row_limit,
-    ChartDataResultType,
     DatasourceDict,
     DTTM_ALIAS,
     find_duplicates,
diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py
index bd2f113c3395f..985cdbf9cf4dd 100644
--- a/superset/reports/commands/execute.py
+++ b/superset/reports/commands/execute.py
@@ -28,6 +28,7 @@
 from superset import app
 from superset.commands.base import BaseCommand
 from superset.commands.exceptions import CommandException
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.extensions import feature_flag_manager, machine_auth_provider_factory
 from superset.models.reports import (
     ReportDataFormat,
@@ -64,7 +65,6 @@
 from superset.reports.notifications.base import NotificationContent
 from superset.reports.notifications.exceptions import NotificationError
 from superset.utils.celery import session_scope
-from superset.utils.core import ChartDataResultFormat, ChartDataResultType
 from superset.utils.csv import get_chart_csv_data, get_chart_dataframe
 from superset.utils.screenshots import (
     BaseScreenshot,
diff --git a/superset/utils/core.py b/superset/utils/core.py
index fad19f4b7ebbd..c5031280077a8 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -174,29 +174,6 @@ class GenericDataType(IntEnum):
     # ROW = 7
 
 
-class ChartDataResultFormat(str, Enum):
-    """
-    Chart data response format
-    """
-
-    CSV = "csv"
-    JSON = "json"
-
-
-class ChartDataResultType(str, Enum):
-    """
-    Chart data response type
-    """
-
-    COLUMNS = "columns"
-    FULL = "full"
-    QUERY = "query"
-    RESULTS = "results"
-    SAMPLES = "samples"
-    TIMEGRAINS = "timegrains"
-    POST_PROCESSED = "post_processed"
-
-
 class DatasourceDict(TypedDict):
     type: str
     id: int
diff --git a/superset/views/core.py b/superset/views/core.py
index 767af4e98b5ef..b8e47920428c6 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -60,6 +60,7 @@
     viz,
 )
 from superset.charts.dao import ChartDAO
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.db_query_status import QueryStatus
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
@@ -459,18 +460,18 @@ def send_data_payload_response(viz_obj: BaseViz, payload: Any) -> FlaskResponse:
     def generate_json(
         self, viz_obj: BaseViz, response_type: Optional[str] = None
     ) -> FlaskResponse:
-        if response_type == utils.ChartDataResultFormat.CSV:
+        if response_type == ChartDataResultFormat.CSV:
             return CsvResponse(
                 viz_obj.get_csv(), headers=generate_download_headers("csv")
             )
 
-        if response_type == utils.ChartDataResultType.QUERY:
+        if response_type == ChartDataResultType.QUERY:
             return self.get_query_string_response(viz_obj)
 
-        if response_type == utils.ChartDataResultType.RESULTS:
+        if response_type == ChartDataResultType.RESULTS:
             return self.get_raw_results(viz_obj)
 
-        if response_type == utils.ChartDataResultType.SAMPLES:
+        if response_type == ChartDataResultType.SAMPLES:
             return self.get_samples(viz_obj)
 
         payload = viz_obj.get_payload()
@@ -598,11 +599,11 @@ def explore_json(
 
         TODO: break into one endpoint for each return shape"""
 
-        response_type = utils.ChartDataResultFormat.JSON.value
-        responses: List[
-            Union[utils.ChartDataResultFormat, utils.ChartDataResultType]
-        ] = list(utils.ChartDataResultFormat)
-        responses.extend(list(utils.ChartDataResultType))
+        response_type = ChartDataResultFormat.JSON.value
+        responses: List[Union[ChartDataResultFormat, ChartDataResultType]] = list(
+            ChartDataResultFormat
+        )
+        responses.extend(list(ChartDataResultType))
         for response_option in responses:
             if request.args.get(response_option) == "true":
                 response_type = response_option
@@ -610,7 +611,7 @@
 
         # Verify user has permission to export CSV file
         if (
-            response_type == utils.ChartDataResultFormat.CSV
+            response_type == ChartDataResultFormat.CSV
             and not security_manager.can_access("can_csv", "Superset")
         ):
             return json_error_response(
@@ -628,7 +629,7 @@ def explore_json(
         # TODO: support CSV, SQL query and other non-JSON types
         if (
             is_feature_enabled("GLOBAL_ASYNC_QUERIES")
-            and response_type == utils.ChartDataResultFormat.JSON
+            and response_type == ChartDataResultFormat.JSON
         ):
             # First, look for the chart query results in the cache.
             try:
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index f0c685ba4c447..696b52154bffb 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -51,15 +51,13 @@
 from superset.models.dashboard import Dashboard
 from superset.models.reports import ReportSchedule, ReportScheduleType
 from superset.models.slice import Slice
-from superset.utils import core as utils
 from superset.utils.core import (
     AnnotationType,
-    ChartDataResultFormat,
     get_example_database,
     get_example_default_schema,
     get_main_database,
 )
-
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
 
 from tests.integration_tests.base_tests import (
@@ -1239,7 +1237,7 @@ def test_chart_data_sample_default_limit(self):
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
-        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["result_type"] = ChartDataResultType.SAMPLES
         del request_payload["queries"][0]["row_limit"]
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
@@ -1258,7 +1256,7 @@ def test_chart_data_sample_custom_limit(self):
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
-        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["result_type"] = ChartDataResultType.SAMPLES
         request_payload["queries"][0]["row_limit"] = 10
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
@@ -1276,7 +1274,7 @@ def test_chart_data_sql_max_row_sample_limit(self):
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
-        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["result_type"] = ChartDataResultType.SAMPLES
         request_payload["queries"][0]["row_limit"] = 10000000
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
@@ -1326,7 +1324,7 @@ def test_chart_data_query_result_type(self):
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
-        request_payload["result_type"] = utils.ChartDataResultType.QUERY
+        request_payload["result_type"] = ChartDataResultType.QUERY
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         self.assertEqual(rv.status_code, 200)
 
@@ -1453,7 +1451,7 @@ def test_chart_data_query_missing_filter(self):
         request_payload["queries"][0]["filters"] = [
             {"col": "non_existent_filter", "op": "==", "val": "foo"},
         ]
-        request_payload["result_type"] = utils.ChartDataResultType.QUERY
+        request_payload["result_type"] = ChartDataResultType.QUERY
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         self.assertEqual(rv.status_code, 200)
         response_payload = json.loads(rv.data.decode("utf-8"))
@@ -1532,7 +1530,7 @@ def test_chart_data_jinja_filter_request(self):
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
-        request_payload["result_type"] = utils.ChartDataResultType.QUERY
+        request_payload["result_type"] = ChartDataResultType.QUERY
         request_payload["queries"][0]["filters"] = [
             {"col": "gender", "op": "==", "val": "boy"}
         ]
@@ -1574,7 +1572,7 @@ def test_chart_data_async_cached_sync_response(self):
 
         class QueryContext:
             result_format = ChartDataResultFormat.JSON
-            result_type = utils.ChartDataResultType.FULL
+            result_type = ChartDataResultType.FULL
 
         cmd_run_val = {
             "query_context": QueryContext(),
@@ -1585,7 +1583,7 @@ class QueryContext:
             ChartDataCommand, "run", return_value=cmd_run_val
         ) as patched_run:
             request_payload = get_query_context("birth_names")
-            request_payload["result_type"] = utils.ChartDataResultType.FULL
+            request_payload["result_type"] = ChartDataResultType.FULL
             rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
             self.assertEqual(rv.status_code, 200)
             data = json.loads(rv.data.decode("utf-8"))
@@ -1997,8 +1995,8 @@ def test_chart_data_timegrains(self):
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
         request_payload["queries"] = [
-            {"result_type": utils.ChartDataResultType.TIMEGRAINS},
-            {"result_type": utils.ChartDataResultType.COLUMNS},
+            {"result_type": ChartDataResultType.TIMEGRAINS},
+            {"result_type": ChartDataResultType.COLUMNS},
         ]
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index cc519cde05d33..d7d93696cd4fe 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -24,18 +24,13 @@
 
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
+from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.query_context import QueryContext
 from superset.common.query_object import QueryObject
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.connectors.sqla.models import SqlMetric
 from superset.extensions import cache_manager
-from superset.utils.core import (
-    AdhocMetricExpressionType,
-    backend,
-    ChartDataResultFormat,
-    ChartDataResultType,
-    TimeRangeEndpoint,
-)
+from superset.utils.core import AdhocMetricExpressionType, backend, TimeRangeEndpoint
 from tests.integration_tests.base_tests import SupersetTestCase
 from tests.integration_tests.fixtures.birth_names_dashboard import (
     load_birth_names_dashboard_with_slices,