Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(ChartDataCommand): separate loading query_context form cache into different module #17405

Merged
merged 1 commit into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 3 additions & 41 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from typing import Any, Optional
from zipfile import ZipFile

import simplejson
from flask import g, make_response, redirect, request, Response, send_file, url_for
from flask import g, redirect, request, Response, send_file, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface
Expand All @@ -49,7 +48,6 @@
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
from superset.charts.post_processing import apply_post_process
from superset.charts.schemas import (
CHART_SCHEMAS,
ChartPostSchema,
Expand All @@ -63,20 +61,17 @@
)
from superset.commands.importers.exceptions import NoValidFilesFoundError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.extensions import event_logger, security_manager
from superset.extensions import event_logger
from superset.models.slice import Slice
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils.core import json_int_dttm_ser
from superset.utils.screenshots import ChartScreenshot
from superset.utils.urls import get_url_path
from superset.views.base_api import (
BaseSupersetModelRestApi,
RelatedFieldFilter,
statsd_metrics,
)
from superset.views.core import CsvResponse, generate_download_headers
from superset.views.filters import FilterRelatedOwners

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -483,39 +478,6 @@ def bulk_delete(self, **kwargs: Any) -> Response:
except ChartBulkDeleteFailedError as ex:
return self.response_422(message=str(ex))

def send_chart_response(
self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
) -> Response:
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format

# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if result_type == ChartDataResultType.POST_PROCESSED:
result = apply_post_process(result, form_data)

if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file
if not security_manager.can_access("can_csv", "Superset"):
return self.response_403()

# return the first result
data = result["queries"][0]["data"]
return CsvResponse(data, headers=generate_download_headers("csv"))

if result_format == ChartDataResultFormat.JSON:
response_data = simplejson.dumps(
{"result": result["queries"]},
default=json_int_dttm_ser,
ignore_nan=True,
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
return resp

return self.response_400(message=f"Unsupported result_format: {result_format}")

@expose("/<pk>/cache_screenshot/", methods=["GET"])
@protect()
@rison(screenshot_query_schema)
Expand Down
10 changes: 0 additions & 10 deletions superset/charts/commands/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from flask import Request
from marshmallow import ValidationError

from superset import cache
from superset.charts.commands.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
Expand Down Expand Up @@ -90,12 +89,3 @@ def validate(self) -> None:
def validate_async_request(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]

def load_query_context_from_cache( # pylint: disable=no-self-use
self, cache_key: str
) -> Dict[str, Any]:
cache_value = cache.get(cache_key)
if not cache_value:
raise ChartDataCacheLoadError("Cached data not found")

return cache_value["data"]
21 changes: 13 additions & 8 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.post_processing import apply_post_process
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.exceptions import QueryObjectValidationError
Expand Down Expand Up @@ -151,7 +152,7 @@ def get_data(self, pk: int) -> Response:
except (TypeError, json.decoder.JSONDecodeError):
form_data = {}

return self.get_data_response(command, form_data=form_data)
return self._get_data_response(command, form_data=form_data)

@expose("/data", methods=["POST"])
@protect()
Expand Down Expand Up @@ -232,7 +233,7 @@ def data(self) -> Response:
):
return self._run_async(command)

return self.get_data_response(command)
return self._get_data_response(command)

@expose("/data/<cache_key>", methods=["GET"])
@protect()
Expand Down Expand Up @@ -276,7 +277,7 @@ def data_from_cache(self, cache_key: str) -> Response:
"""
command = ChartDataCommand()
try:
cached_data = command.load_query_context_from_cache(cache_key)
cached_data = self._load_query_context_form_from_cache(cache_key)
command.set_query_context(cached_data)
command.validate()
except ChartDataCacheLoadError:
Expand All @@ -286,7 +287,7 @@ def data_from_cache(self, cache_key: str) -> Response:
message=_("Request is incorrect: %(error)s", error=error.messages)
)

return self.get_data_response(command, True)
return self._get_data_response(command, True)

def _run_async(self, command: ChartDataCommand) -> Response:
"""
Expand All @@ -302,7 +303,7 @@ def _run_async(self, command: ChartDataCommand) -> Response:

# If the chart query has already been cached, return it immediately.
if already_cached_result:
return self.send_chart_response(result)
return self._send_chart_response(result)

# Otherwise, kick off a background job to run the chart query.
# Clients will either poll or be notified of query completion,
Expand All @@ -316,7 +317,7 @@ def _run_async(self, command: ChartDataCommand) -> Response:
result = command.run_async(g.user.get_id())
return self.response(202, **result)

def send_chart_response(
def _send_chart_response(
self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
) -> Response:
result_type = result["query_context"].result_type
Expand Down Expand Up @@ -349,7 +350,7 @@ def send_chart_response(

return self.response_400(message=f"Unsupported result_format: {result_format}")

def get_data_response(
def _get_data_response(
self,
command: ChartDataCommand,
force_cached: bool = False,
Expand All @@ -362,4 +363,8 @@ def get_data_response(
except ChartDataQueryFailedError as exc:
return self.response_400(message=exc.message)

return self.send_chart_response(result, form_data)
return self._send_chart_response(result, form_data)

# pylint: disable=invalid-name, no-self-use
def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
amitmiran137 marked this conversation as resolved.
Show resolved Hide resolved
return QueryContextCacheLoader.load(cache_key)
30 changes: 30 additions & 0 deletions superset/charts/data/query_context_cache_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict

from superset import cache
from superset.charts.commands.exceptions import ChartDataCacheLoadError


class QueryContextCacheLoader: # pylint: disable=too-few-public-methods
@staticmethod
def load(cache_key: str) -> Dict[str, Any]:
cache_value = cache.get(cache_key)
if not cache_value:
raise ChartDataCacheLoadError("Cached data not found")

return cache_value["data"]
18 changes: 9 additions & 9 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,15 +1620,15 @@ def test_chart_data_async_invalid_token(self):

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
def test_chart_data_cache(self, load_qc_mock):
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
def test_chart_data_cache(self, cache_loader):
"""
Chart data cache API: Test chart data async cache request
"""
async_query_manager.init_app(app)
self.login(username="admin")
query_context = get_query_context("birth_names")
load_qc_mock.return_value = query_context
cache_loader.load.return_value = query_context
orig_run = ChartDataCommand.run

def mock_run(self, **kwargs):
Expand All @@ -1647,16 +1647,16 @@ def mock_run(self, **kwargs):
self.assertEqual(data["result"][0]["rowcount"], expected_row_count)

@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_cache_run_failed(self, load_qc_mock):
def test_chart_data_cache_run_failed(self, cache_loader):
"""
Chart data cache API: Test chart data async cache request with run failure
"""
async_query_manager.init_app(app)
self.login(username="admin")
query_context = get_query_context("birth_names")
load_qc_mock.return_value = query_context
cache_loader.load.return_value = query_context
rv = self.get_assert_metric(
f"{CHART_DATA_URI}/test-cache-key", "data_from_cache"
)
Expand All @@ -1666,15 +1666,15 @@ def test_chart_data_cache_run_failed(self, load_qc_mock):
self.assertEqual(data["message"], "Error loading data from cache")

@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@mock.patch.object(ChartDataCommand, "load_query_context_from_cache")
@mock.patch("superset.charts.data.api.QueryContextCacheLoader")
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_cache_no_login(self, load_qc_mock):
def test_chart_data_cache_no_login(self, cache_loader):
"""
Chart data cache API: Test chart data async cache request (no login)
"""
async_query_manager.init_app(app)
query_context = get_query_context("birth_names")
load_qc_mock.return_value = query_context
cache_loader.load.return_value = query_context
orig_run = ChartDataCommand.run

def mock_run(self, **kwargs):
Expand Down